From 14062a882154c1de1dbcd69d2bc482cb8cb79b81 Mon Sep 17 00:00:00 2001 From: cheng cheng <472491134@qq.com> Date: Tue, 23 Mar 2021 22:22:41 +0800 Subject: [PATCH] Complete support 1 regst 1 blob (#4474) * half implement of build task graph by 1regst1blob * Complete support 1 regst 1 blob * fix check * Add Lbis in TaskEdge and check valid * reduce NormalForward out regst name prefix * refine hasher and proxy key * fix bug of ProxyKey == * fix bug of collective boxing broadcast task edge lbi Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../boxing/b21_sub_task_graph_builder.cpp | 4 +- ...llective_boxing_sub_task_graph_builder.cpp | 31 +- .../naive_b2b_sub_task_graph_builder.cpp | 4 +- .../naive_b2p_sub_task_graph_builder.cpp | 4 +- .../one_to_one_sub_task_graph_builder.cpp | 3 +- .../slice_boxing_sub_task_graph_builder.cpp | 35 +-- .../boxing/sub_task_graph_builder_context.cpp | 66 ----- .../boxing/sub_task_graph_builder_context.h | 16 -- .../core/graph/boxing_identity_task_node.cpp | 4 +- .../core/graph/boxing_identity_task_node.h | 7 +- oneflow/core/graph/boxing_zeros_task_node.cpp | 4 +- oneflow/core/graph/boxing_zeros_task_node.h | 5 +- .../collective_boxing_pack_task_node.cpp | 4 +- .../graph/collective_boxing_pack_task_node.h | 6 +- .../graph/collective_boxing_task_node.cpp | 3 +- .../core/graph/collective_boxing_task_node.h | 7 +- .../collective_boxing_unpack_task_node.cpp | 4 +- .../collective_boxing_unpack_task_node.h | 5 +- oneflow/core/graph/copy_task_node.cpp | 7 +- oneflow/core/graph/copy_task_node.h | 8 +- .../graph/normal_forward_compute_task_node.h | 2 - oneflow/core/graph/slice_boxing_task_node.cpp | 7 +- oneflow/core/graph/slice_boxing_task_node.h | 6 +- oneflow/core/graph/task_graph.cpp | 272 ++++++++---------- oneflow/core/graph/task_graph.h | 65 +++-- oneflow/core/graph/task_node.cpp | 32 ++- oneflow/core/graph/task_node.h | 6 + oneflow/core/graph/transport_task_node.h | 39 +++ .../normal_forward_compute_task_node.cpp | 95 ++---- oneflow/core/job/compiler.cpp | 2 + 30 files changed, 344 insertions(+), 409 deletions(-) create mode 100644 oneflow/core/graph/transport_task_node.h diff --git a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp index 02b4268c7c..1cc3eefcde 100644 --- a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp @@ -31,8 +31,8 @@ Maybe B21SubTskGphBuilder::Build( const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_parallel_id); TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); - TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), - out_parallel_desc, out_parallel_id); + TaskNode* proxy = + ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_parallel_id); sorted_out_tasks->push_back(proxy); return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", "")); } else { diff --git a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp index b0571100ab..73acbbfa2c 100644 --- a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp @@ -73,7 +73,7 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node, CHECK_NOTNULL(stream_index_generator); 
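// A minimal sketch of the connection pattern the collective boxing builders switch to in
// this patch: the LogicalBlobId being moved is passed both to the task node's Init() and to
// the edge linking producer and consumer, via TaskGraph::ConnectWithLbi (added later in this
// patch). Variable names below mirror the surrounding builders and are illustrative only:
//
//   auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
//   NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi,
//                          logical_blob_desc, OpType::kOpTypeAllReduce, -1);
//   ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);
//   sorted_out_tasks->push_back(collective_node);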
auto stream_index = stream_index_generator->GenerateNcclStreamIndex(); const int64_t thrd_id = SerializeStreamIdToInt64(StreamId{device_id, stream_index}); - node->Init(machine_id, thrd_id, op_conf); + node->Init(machine_id, thrd_id, lbi, op_conf); } int64_t FindRootParallelId(const ParallelDesc& multi_device, const ParallelDesc& sole_device) { @@ -117,7 +117,7 @@ class NcclCollectiveBoxingAllReduceSubTskGphBuilder final : public SubTskGphBuil auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllReduce, -1); - Connect(in_node, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); sorted_out_tasks->push_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllReduceSubTskGphBuilder", "")); @@ -154,7 +154,7 @@ class NcclCollectiveBoxingReduceScatterSubTskGphBuilder final : public SubTskGph auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduceScatter, -1); - Connect(in_node, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); sorted_out_tasks->push_back(collective_node); } return TRY( @@ -190,11 +190,11 @@ class NcclCollectiveBoxingAllGatherSubTskGphBuilder final : public SubTskGphBuil FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { TaskNode* in_node = sorted_in_tasks.at(i); TaskNode* in_node_proxy = - ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i); + ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i); auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, -1); - Connect(in_node_proxy, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(in_node_proxy, collective_node, lbi); sorted_out_tasks->push_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllGatherSubTskGphBuilder", "")); @@ -232,7 +232,7 @@ class NcclCollectiveBoxingReduceSubTskGphBuilder final : public SubTskGphBuilder auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduce, root_parallel_id); - Connect(in_node, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi); if (i == root_parallel_id) { sorted_out_tasks->push_back(collective_node); } else { @@ -288,12 +288,12 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su slice_node->ConnectToSrcNodeWithSlice(in_node, ctx->task_graph()->NewEdge(), in_slice); // copy to dst gpu TaskNode* slice_node_proxy = - ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_parallel_desc, out_id); + ctx->task_graph()->GetProxyNode(slice_node, lbi, out_parallel_desc, out_id); // allgather auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, out_parallel_desc, out_id, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, -1); - Connect(slice_node_proxy, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(slice_node_proxy, collective_node, lbi); sorted_out_tasks->push_back(collective_node); } return 
TRY(BuildSubTskGphBuilderStatus( @@ -330,8 +330,8 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil auto* cpu_in_node = sorted_in_tasks.front(); root_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(out_parallel_desc, in_parallel_desc, 0); - gpu_in_node = ctx->GetProxyNode(cpu_in_node, cpu_in_node->MemZoneId121(), out_parallel_desc, - root_parallel_id); + gpu_in_node = + ctx->task_graph()->GetProxyNode(cpu_in_node, lbi, out_parallel_desc, root_parallel_id); } else if (in_parallel_desc.device_type() == DeviceType::kGPU) { root_parallel_id = FindRootParallelId(out_parallel_desc, in_parallel_desc); @@ -347,10 +347,11 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeBroadcast, root_parallel_id); if (i == root_parallel_id) { - Connect(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(gpu_in_node, collective_node, lbi); } else { gpu_in_node->BuildCtrlRegstDesc(collective_node); - Connect(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node); + Connect(gpu_in_node, ctx->task_graph()->NewTaskEdgeWithLbi(lbi), + collective_node); } sorted_out_tasks->push_back(collective_node); } @@ -402,18 +403,18 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde ctx->task_graph()->NewNode(); pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); - Connect(in_node, ctx->task_graph()->NewEdge(), pack_node); + ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi); auto* collective_node = ctx->task_graph()->NewNode(); NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAll2All, -1); - Connect(pack_node, ctx->task_graph()->NewEdge(), collective_node); + ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi); CollectiveBoxingUnpackTaskNode* unpack_node = ctx->task_graph()->NewNode(); unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); - Connect(collective_node, ctx->task_graph()->NewEdge(), unpack_node); + ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi); sorted_out_tasks->push_back(unpack_node); } return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAll2AllSubTskGphBuilder", "")); diff --git a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp index a483ae9b22..25847a6095 100644 --- a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp @@ -31,8 +31,8 @@ Maybe NaiveB2BSubTskGphBuilder::Build( const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_id); TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); - TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), - out_parallel_desc, out_id); + TaskNode* proxy = + ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id); sorted_out_tasks->push_back(proxy); } return TRY(BuildSubTskGphBuilderStatus("NaiveB2BSubTskGphBuilder", "")); diff --git a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp 
b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp index 0c51b13532..6c8b1514c7 100644 --- a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp @@ -50,8 +50,8 @@ Maybe NaiveB2PSubTskGphBuilder::Build( const int64_t nearest_in_id = out_id2nearest_in_id.at(out_id); TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_id); if (out_id == nearest_out_node_idx) { - TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), - out_parallel_desc, out_id); + TaskNode* proxy = + ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id); sorted_out_tasks->push_back(proxy); } else { diff --git a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp index f40a71cf05..11939a8beb 100644 --- a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp @@ -30,8 +30,7 @@ Maybe OneToOneSubTskGphBuilder::Build( && in_sbp_parallel == out_sbp_parallel)) { for (int64_t i = 0; i < in_parallel_desc.parallel_num(); ++i) { TaskNode* in_node = sorted_in_tasks.at(i); - // TODO(liujuncheng): use lbi - TaskNode* proxy = ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i); + TaskNode* proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i); sorted_out_tasks->push_back(proxy); } return TRY(BuildSubTskGphBuilderStatus("OneToOneSubTskGphBuilder", "")); diff --git a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp index 30daf94b8a..a41b08b458 100644 --- a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp @@ -163,7 +163,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( dst_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), src_slice); return dst_node; }; - const auto BuildSubTaskGphS2B = [&ctx, &CreateBoxingNode121, &NewEdge]( + const auto BuildSubTaskGphS2B = [&ctx, &CreateBoxingNode121, &NewEdge, &lbi]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, @@ -183,9 +183,8 @@ Maybe SliceBoxingSubTskGphBuilder::Build( if (SubTskGphBuilderUtil::IsOnSameGPU(in_node, out_node)) { out_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); } else { - TaskNode* proxy_node = - ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_node->machine_id(), - Global::Get()->CpuMemZoneId()); + TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( + in_node, lbi, out_node->machine_id(), Global::Get()->CpuMemZoneId()); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice); } } @@ -291,9 +290,8 @@ Maybe SliceBoxingSubTskGphBuilder::Build( in_slices.at(in_id)); } } - TaskNode* local_add_proxy_node = - ctx->GetProxyNode(local_concat_node, Global::Get()->CpuMemZoneId(), - out_node->machine_id(), Global::Get()->CpuMemZoneId()); + TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode( + local_concat_node, lbi, out_node->machine_id(), Global::Get()->CpuMemZoneId()); out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), concat_slice); } } @@ -353,9 +351,8 @@ Maybe SliceBoxingSubTskGphBuilder::Build( for (const int64_t in_id : in_parallel_ids) { local_add_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), 
NewEdge(), in_slice); } - TaskNode* local_add_proxy_node = - ctx->GetProxyNode(local_add_node, Global::Get()->CpuMemZoneId(), - out_node->machine_id(), Global::Get()->CpuMemZoneId()); + TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode( + local_add_node, lbi, out_node->machine_id(), Global::Get()->CpuMemZoneId()); out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), out_slice); } } @@ -409,18 +406,17 @@ Maybe SliceBoxingSubTskGphBuilder::Build( const int64_t out_machine_id = machine_id7out_parallel_ids.first; TaskNode* in_box_node = nullptr; if (out_box_nodes.size() == 1) { - in_box_node = ctx->GetProxyNode( - out_box_nodes.front(), out_box_nodes.front()->MemZoneId121(), - machine_id7out_parallel_ids.first, Global::Get()->CpuMemZoneId()); + in_box_node = ctx->task_graph()->GetProxyNode(out_box_nodes.front(), lbi, + machine_id7out_parallel_ids.first, + Global::Get()->CpuMemZoneId()); } else { auto* add_node = ctx->task_graph()->NewNode(); add_node->Init(lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first, Global::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first), Global::Get()->CpuMemZoneId()); for (TaskNode* out_box_node : out_box_nodes) { - TaskNode* out_boxing_node_proxy = - ctx->GetProxyNode(out_box_node, out_box_node->MemZoneId121(), out_machine_id, - Global::Get()->CpuMemZoneId()); + TaskNode* out_boxing_node_proxy = ctx->task_graph()->GetProxyNode( + out_box_node, lbi, out_machine_id, Global::Get()->CpuMemZoneId()); add_node->ConnectToSrcNodeWithSlice(out_boxing_node_proxy, NewEdge(), slice); } in_box_node = add_node; @@ -435,8 +431,8 @@ Maybe SliceBoxingSubTskGphBuilder::Build( } else { UNIMPLEMENTED(); } - (*out_nodes)[out_id] = ctx->GetProxyNode(in_box_node, Global::Get()->CpuMemZoneId(), - out_machine_id, mem_zone_id); + (*out_nodes)[out_id] = + ctx->task_graph()->GetProxyNode(in_box_node, lbi, out_machine_id, mem_zone_id); } } }; @@ -460,8 +456,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( SliceBoxingTaskNode* slice_node = CreateBoxingNode121(in_pd, nearest_idx, out_slice, kSliceBoxingTaskModeCopy); slice_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); - TaskNode* out_node = - ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_pd, out_id); + TaskNode* out_node = ctx->task_graph()->GetProxyNode(slice_node, lbi, out_pd, out_id); out_nodes->push_back(out_node); } diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp b/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp index 69aff5812a..2e76d302ca 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp @@ -21,70 +21,4 @@ SubTskGphBuilderCtx::SubTskGphBuilderCtx(TaskGraph* task_graph) : task_graph_(ta TaskGraph* SubTskGphBuilderCtx::task_graph() { return task_graph_; } -TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, - int64_t dst_machine_id, int64_t dst_mem_zone_id) { - const auto key = std::make_pair(dst_machine_id, dst_mem_zone_id); - if (node2proxies_.find(src_node) != node2proxies_.cend() - && node2proxies_.at(src_node).find(key) != node2proxies_.at(src_node).cend()) { - return node2proxies_.at(src_node).at(key); - } else { - if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) { - node2proxies_[src_node][key] = src_node; - return src_node; - } else if (Global::Get()->IsGpuMemZone(dst_mem_zone_id)) { - TaskNode* proxy_on_dst_host = GetProxyNode(src_node, 
src_mem_zone_id, dst_machine_id, - Global::Get()->CpuMemZoneId()); - CopyHdTaskNode* copy_task = task_graph()->NewNode(); - copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(), - Global::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id)); - Connect(proxy_on_dst_host, task_graph()->NewEdge(), copy_task); - node2proxies_[src_node][key] = copy_task; - return copy_task; - } else if (Global::Get()->IsCpuMemZone(dst_mem_zone_id)) { - if (src_node->machine_id() == dst_machine_id) { - if (Global::Get()->IsGpuMemZone(src_mem_zone_id)) { - CopyHdTaskNode* copy_task = task_graph()->NewNode(); - copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(), - Global::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id)); - Connect(src_node, task_graph()->NewEdge(), copy_task); - node2proxies_[src_node][key] = copy_task; - return copy_task; - } else { - UNIMPLEMENTED(); - } - } else { - TaskNode* proxy_on_src_host = - GetProxyNode(src_node, src_mem_zone_id, src_node->machine_id(), - Global::Get()->CpuMemZoneId()); - CopyCommNetTaskNode* copy_comm_net_task = task_graph()->NewNode(); - copy_comm_net_task->Init(dst_machine_id); - Connect(proxy_on_src_host, task_graph()->NewEdge(), copy_comm_net_task); - node2proxies_[src_node][key] = copy_comm_net_task; - return copy_comm_net_task; - } - } else { - UNIMPLEMENTED(); - } - } -} - -TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, const int64_t src_mem_zone_id, - const ParallelDesc& dst_parallel_desc, - const int64_t dst_parallel_id) { - const int64_t dst_machine_id = - CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); - int64_t dst_mem_zone_id; - const IDMgr* id_mgr = Global::Get(); - if (dst_parallel_desc.device_type() == DeviceType::kCPU) { - dst_mem_zone_id = id_mgr->CpuMemZoneId(); - } else if (dst_parallel_desc.device_type() == DeviceType::kGPU) { - const int64_t dst_dev_phy_id = - CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); - dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id); - } else { - UNIMPLEMENTED(); - } - return GetProxyNode(src_node, src_mem_zone_id, dst_machine_id, dst_mem_zone_id); -} - } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_context.h b/oneflow/core/graph/boxing/sub_task_graph_builder_context.h index 62056c2a08..749c330a47 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_context.h +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_context.h @@ -18,13 +18,9 @@ limitations under the License. 
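// With both GetProxyNode overloads and the node2proxies_ cache deleted from
// SubTskGphBuilderCtx, the context object is reduced to a handle on task_graph(); builders
// now request a proxy per LogicalBlobId directly from the task graph, e.g. (usage sketch,
// names as in the builders above):
//
//   TaskNode* proxy =
//       ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id);
//   sorted_out_tasks->push_back(proxy);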
#include "oneflow/core/common/util.h" #include "oneflow/core/graph/task_graph.h" -#include "oneflow/core/memory/memory_allocator.h" namespace oneflow { -class TaskGraph; -class TaskNode; - class SubTskGphBuilderCtx final { public: OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilderCtx); @@ -32,21 +28,9 @@ class SubTskGphBuilderCtx final { virtual ~SubTskGphBuilderCtx() = default; virtual TaskGraph* task_graph(); - TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, int64_t dst_machine_id, - int64_t dst_mem_zone_id); - TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, - const ParallelDesc& dst_parallel_desc, const int64_t dst_parallel_id); - template - void ConnectAll121(const std::vector& src_nodes, const std::vector& dst_nodes) { - CHECK_EQ(src_nodes.size(), dst_nodes.size()); - FOR_RANGE(int64_t, i, 0, dst_nodes.size()) { - Connect(src_nodes.at(i), task_graph()->NewEdge(), dst_nodes.at(i)); - } - } private: TaskGraph* task_graph_; - HashMap, TaskNode*>> node2proxies_; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing_identity_task_node.cpp b/oneflow/core/graph/boxing_identity_task_node.cpp index 2ad5b4f64b..6646f6682c 100644 --- a/oneflow/core/graph/boxing_identity_task_node.cpp +++ b/oneflow/core/graph/boxing_identity_task_node.cpp @@ -19,9 +19,9 @@ limitations under the License. namespace oneflow { void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi) { - lbi_ = lbi; set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); } void BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() { @@ -39,7 +39,7 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() { OperatorConf op_conf; op_conf.set_name("System-Boxing-Identity-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); - *op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi_; + *op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi(); std::shared_ptr sole_op = ConstructOp(op_conf); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); diff --git a/oneflow/core/graph/boxing_identity_task_node.h b/oneflow/core/graph/boxing_identity_task_node.h index a99cf2e110..4a1935a71c 100644 --- a/oneflow/core/graph/boxing_identity_task_node.h +++ b/oneflow/core/graph/boxing_identity_task_node.h @@ -15,11 +15,12 @@ limitations under the License. 
*/ #ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" + +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class BoxingIdentityTaskNode : public TaskNode { +class BoxingIdentityTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityTaskNode); BoxingIdentityTaskNode() = default; @@ -33,8 +34,6 @@ class BoxingIdentityTaskNode : public TaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; - - LogicalBlobId lbi_; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing_zeros_task_node.cpp b/oneflow/core/graph/boxing_zeros_task_node.cpp index bb574f06f3..87d25fbdf1 100644 --- a/oneflow/core/graph/boxing_zeros_task_node.cpp +++ b/oneflow/core/graph/boxing_zeros_task_node.cpp @@ -20,9 +20,9 @@ namespace oneflow { void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& shape, DataType data_type, const Shape& time_shape) { - lbi_ = lbi; set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); shape_ = shape; data_type_ = data_type; time_shape_ = time_shape; @@ -42,7 +42,7 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() { OperatorConf op_conf; op_conf.set_name("System-Boxing-Zeros-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); - *op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi_; + *op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi(); shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape()); op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_); std::shared_ptr sole_op = ConstructOp(op_conf); diff --git a/oneflow/core/graph/boxing_zeros_task_node.h b/oneflow/core/graph/boxing_zeros_task_node.h index b47c2efc5f..5a99df1d5e 100644 --- a/oneflow/core/graph/boxing_zeros_task_node.h +++ b/oneflow/core/graph/boxing_zeros_task_node.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_ -#include "oneflow/core/graph/compute_task_node.h" +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class BoxingZerosTaskNode : public TaskNode { +class BoxingZerosTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(BoxingZerosTaskNode); BoxingZerosTaskNode() = default; @@ -36,7 +36,6 @@ class BoxingZerosTaskNode : public TaskNode { void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; - LogicalBlobId lbi_; Shape shape_; DataType data_type_; Shape time_shape_; diff --git a/oneflow/core/graph/collective_boxing_pack_task_node.cpp b/oneflow/core/graph/collective_boxing_pack_task_node.cpp index f62be0f354..1b4ce333f0 100644 --- a/oneflow/core/graph/collective_boxing_pack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_pack_task_node.cpp @@ -23,9 +23,9 @@ void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num) { - lbi_ = lbi; set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); logical_shape_ = logical_shape; parallel_num_ = parallel_num; src_sbp_parallel_ = src_sbp_parallel; @@ -48,7 +48,7 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() { op_conf.set_name("System-Collective-Boxing-Pack-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); auto* collective_boxing_pack_conf = op_conf.mutable_collective_boxing_pack_conf(); - *collective_boxing_pack_conf->mutable_lbi() = lbi_; + *collective_boxing_pack_conf->mutable_lbi() = lbi(); logical_shape_.ToProto(collective_boxing_pack_conf->mutable_logical_shape()); *collective_boxing_pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; *collective_boxing_pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; diff --git a/oneflow/core/graph/collective_boxing_pack_task_node.h b/oneflow/core/graph/collective_boxing_pack_task_node.h index 3505f1142c..9230019e4f 100644 --- a/oneflow/core/graph/collective_boxing_pack_task_node.h +++ b/oneflow/core/graph/collective_boxing_pack_task_node.h @@ -15,11 +15,12 @@ limitations under the License. */ #ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" + +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class CollectiveBoxingPackTaskNode : public TaskNode { +class CollectiveBoxingPackTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackTaskNode); CollectiveBoxingPackTaskNode() = default; @@ -36,7 +37,6 @@ class CollectiveBoxingPackTaskNode : public TaskNode { void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; - LogicalBlobId lbi_; Shape logical_shape_; SbpParallel src_sbp_parallel_; SbpParallel dst_sbp_parallel_; diff --git a/oneflow/core/graph/collective_boxing_task_node.cpp b/oneflow/core/graph/collective_boxing_task_node.cpp index 988b31b91a..2e3f2ad101 100644 --- a/oneflow/core/graph/collective_boxing_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_task_node.cpp @@ -19,9 +19,10 @@ limitations under the License. 
namespace oneflow { void CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id, - const OperatorConf& op_conf) { + const LogicalBlobId& lbi, const OperatorConf& op_conf) { set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); op_conf_ = op_conf; } diff --git a/oneflow/core/graph/collective_boxing_task_node.h b/oneflow/core/graph/collective_boxing_task_node.h index a79cdf1fbb..12ee316f22 100644 --- a/oneflow/core/graph/collective_boxing_task_node.h +++ b/oneflow/core/graph/collective_boxing_task_node.h @@ -16,17 +16,18 @@ limitations under the License. #ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class CollectiveBoxingGenericTaskNode : public TaskNode { +class CollectiveBoxingGenericTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericTaskNode); CollectiveBoxingGenericTaskNode() = default; ~CollectiveBoxingGenericTaskNode() override = default; - void Init(int64_t machine_id, int64_t thrd_id, const OperatorConf& op_conf); + void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, + const OperatorConf& op_conf); private: void BuildExecGphAndRegst() override; diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp index 8dadc215b7..4ced9293a6 100644 --- a/oneflow/core/graph/collective_boxing_unpack_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp @@ -23,9 +23,9 @@ void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id, const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, const int64_t parallel_num) { - lbi_ = lbi; set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); logical_shape_ = logical_shape; parallel_num_ = parallel_num; src_sbp_parallel_ = src_sbp_parallel; @@ -48,7 +48,7 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() { op_conf.set_name("System-Collective-Boxing-Unpack-" + NewUniqueId()); op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); auto* collective_boxing_unpack_conf = op_conf.mutable_collective_boxing_unpack_conf(); - *collective_boxing_unpack_conf->mutable_lbi() = lbi_; + *collective_boxing_unpack_conf->mutable_lbi() = lbi(); logical_shape_.ToProto(collective_boxing_unpack_conf->mutable_logical_shape()); *collective_boxing_unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; *collective_boxing_unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.h b/oneflow/core/graph/collective_boxing_unpack_task_node.h index 940570ff04..4d74b6952c 100644 --- a/oneflow/core/graph/collective_boxing_unpack_task_node.h +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class CollectiveBoxingUnpackTaskNode : public TaskNode { +class CollectiveBoxingUnpackTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackTaskNode); CollectiveBoxingUnpackTaskNode() = default; @@ -38,7 +38,6 @@ class CollectiveBoxingUnpackTaskNode : public TaskNode { void ConsumeAllRegsts() final; void InferProducedDataRegstTimeShape() final; - LogicalBlobId lbi_; Shape logical_shape_; SbpParallel src_sbp_parallel_; SbpParallel dst_sbp_parallel_; diff --git a/oneflow/core/graph/copy_task_node.cpp b/oneflow/core/graph/copy_task_node.cpp index c5438688c6..465066bd11 100644 --- a/oneflow/core/graph/copy_task_node.cpp +++ b/oneflow/core/graph/copy_task_node.cpp @@ -57,7 +57,8 @@ void CopyTaskNode::BuildExecGphAndRegst() { void CopyTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } -void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int64_t dev_phy_id) { +void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int64_t dev_phy_id, + const LogicalBlobId& lbi) { copy_type_ = copy_type; set_machine_id(machine_id); DeviceId device_id{static_cast(machine_id), DeviceType::kGPU, @@ -73,6 +74,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int6 UNIMPLEMENTED(); } set_thrd_id(SerializeStreamIdToInt64(StreamId{device_id, stream_index})); + set_lbi(lbi); } void CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) { @@ -98,7 +100,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() { return conf; } -void CopyCommNetTaskNode::Init(int64_t machine_id) { +void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) { set_machine_id(machine_id); DeviceId device_id{static_cast(machine_id), DeviceType::kCPU, DeviceId::kCPUDeviceIndex}; @@ -107,6 +109,7 @@ void CopyCommNetTaskNode::Init(int64_t machine_id) { CHECK_NOTNULL(generator); StreamId stream_id{device_id, generator->GenerateCommNetStreamIndex()}; set_thrd_id(SerializeStreamIdToInt64(stream_id)); + set_lbi(lbi); } void CopyCommNetTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) { diff --git a/oneflow/core/graph/copy_task_node.h b/oneflow/core/graph/copy_task_node.h index 6a09f3331d..2b0f5ac71b 100644 --- a/oneflow/core/graph/copy_task_node.h +++ b/oneflow/core/graph/copy_task_node.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/transport_task_node.h" namespace oneflow { -class CopyTaskNode : public TaskNode { +class CopyTaskNode : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(CopyTaskNode); CopyTaskNode() = default; @@ -45,7 +45,7 @@ class CopyHdTaskNode final : public CopyTaskNode { TaskType GetTaskType() const override { return TaskType::kCopyHd; } - void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id); + void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi); CopyHdOpConf::Type copy_type() const { return copy_type_; } int64_t MemZoneId121() const override { @@ -74,7 +74,7 @@ class CopyCommNetTaskNode final : public CopyTaskNode { TaskType GetTaskType() const override { return TaskType::kCopyCommNet; } - void Init(int64_t machine_id); + void Init(int64_t machine_id, const LogicalBlobId& lbi); private: void InitProducedRegstMemCase(MemoryCase*) override; diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h index f375f80707..c0ac525dfa 100644 --- a/oneflow/core/graph/normal_forward_compute_task_node.h +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -34,8 +34,6 @@ class NormalForwardCompTaskNode final : public CompTaskNode { bool HasBackwardCompTaskNode(); private: - bool IsAllOutNodeNormalForward() const; - bool CanProduceSeperatedRegstsForEachOutBlob() const; void ProduceOutRegstByNameAndBlockNum(const std::string& name, size_t mem_block_num); void BuildExecGphAndRegst() override; void BuildExecGphStructAndBindInRegst(); diff --git a/oneflow/core/graph/slice_boxing_task_node.cpp b/oneflow/core/graph/slice_boxing_task_node.cpp index 0b620e2c8c..89ddc7b17f 100644 --- a/oneflow/core/graph/slice_boxing_task_node.cpp +++ b/oneflow/core/graph/slice_boxing_task_node.cpp @@ -21,13 +21,13 @@ namespace oneflow { void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id, int64_t mem_zone_id) { - lbi_ = lbi; out_slice_ = out_slice; out_shape_ = out_slice.shape(); mode_ = mode; mem_zone_id_ = mem_zone_id; set_machine_id(machine_id); set_thrd_id(thrd_id); + set_lbi(lbi); } void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, @@ -77,7 +77,7 @@ void SliceBoxingTaskNode::BuildExecGphAndRegst() { node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in_" + std::to_string(i))); } std::shared_ptr out_regst = GetProducedRegst("out"); - out_regst->AddLbi(lbi_); + out_regst->AddLbi(lbi()); node->BindBnWithRegst(op->SoleObn(), out_regst); node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); node->InferBlobDescs(parallel_ctx()); @@ -94,6 +94,7 @@ void SliceBoxingTaskNode::SetInDataEdgeSlice(const TaskEdge* edge, const TensorS void SliceBoxingTaskNode::ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge, const TensorSliceView& slice) { + edge->AddLbi(lbi()); Connect(src, edge, this); SetInDataEdgeSlice(edge, slice); } @@ -104,7 +105,7 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() { OperatorConf op_conf{}; op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type()))); SliceBoxingConf boxing_conf{}; - *boxing_conf.mutable_lbi() = lbi_; + *boxing_conf.mutable_lbi() = lbi(); out_slice_.ToProto(boxing_conf.mutable_out_slice()); 
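// ConnectToSrcNodeWithSlice (above) now stamps the edge with the node's LogicalBlobId before
// wiring it, so every in-data edge of a SliceBoxingTaskNode records the single blob it
// transports; this is the per-edge bookkeeping that TaskEdge::CheckRegstLbiValid (added later
// in this patch) verifies against the bound regsts:
//
//   edge->AddLbi(lbi());
//   Connect(src, edge, this);
//   SetInDataEdgeSlice(edge, slice);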
out_shape_.ToProto(boxing_conf.mutable_out_shape()); for (const TaskEdge* edge : ordered_in_data_edges_) { diff --git a/oneflow/core/graph/slice_boxing_task_node.h b/oneflow/core/graph/slice_boxing_task_node.h index caaddc391d..ff72380529 100644 --- a/oneflow/core/graph/slice_boxing_task_node.h +++ b/oneflow/core/graph/slice_boxing_task_node.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_ -#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { @@ -27,7 +27,7 @@ enum SliceBoxingTaskMode { kSliceBoxingTaskModeAdd, }; -class SliceBoxingTaskNode final : public TaskNode { +class SliceBoxingTaskNode final : public TransportTaskNode { public: OF_DISALLOW_COPY_AND_MOVE(SliceBoxingTaskNode); SliceBoxingTaskNode() = default; @@ -49,10 +49,10 @@ class SliceBoxingTaskNode final : public TaskNode { void InferProducedDataRegstTimeShape() override; OperatorConf GetBoxingOpConf(); void InitProducedRegstMemCase(MemoryCase*) override; + int64_t MemZoneId121() const override { return mem_zone_id_; } HashMap in_data_edge2slice_; std::vector ordered_in_data_edges_; - LogicalBlobId lbi_; TensorSliceView out_slice_; Shape out_shape_; SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid; diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 97c5355e48..ec11176719 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -395,32 +395,24 @@ TaskGraph::TaskGraph() { sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); - HashMap> logical2sorted_comp_tasks; - HashMap>> buf_task; - auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) { - auto& buf_vec = buf_task[task_node][machine_id]; - if (buf_vec.empty()) { - buf_vec.assign(Global::Get()->MemZoneNum(), nullptr); - } - return &(buf_vec.at(mem_zone_id)); - }; + HashMap> op_node2sorted_comp_tasks; op_graph->ForEachNode([&](const OpNode* op_node) { - std::vector* sorted_comp_tasks = &(logical2sorted_comp_tasks[op_node]); + std::vector* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]); GenSortedCompTaskNodes(op_node, sorted_comp_tasks); for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); } }); op_graph->ForEachEdge([&](const OpEdge* op_edge) { - // TODO(chengcheng): ForEachLbi for one regst one blob. 
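// The TODO removed above is what the rest of this file implements: boxing build methods now
// iterate op_edge->lbis() and build one transport path per LogicalBlobId instead of one
// anonymous path per edge. A sketch of the one-to-one case as it appears further down:
//
//   for (const LogicalBlobId& lbi : op_edge->lbis()) {
//     BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi);
//   }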
BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge); - (this->*method)(op_edge, logical2sorted_comp_tasks.at(op_edge->src_node()), - logical2sorted_comp_tasks.at(op_edge->dst_node()), MutBufTask); + (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()), + op_node2sorted_comp_tasks.at(op_edge->dst_node())); }); + ForEachOpGraphNecessaryCtrlEdge( op_graph, [&](const OpNode* src, const OpNode* dst, int64_t ctrl_regst_num) { - const auto& src_task_nodes = logical2sorted_comp_tasks.at(src); - const auto& dst_task_nodes = logical2sorted_comp_tasks.at(dst); + const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src); + const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst); if (src->op().op_conf().has_src_subset_tick_conf()) { UNIMPLEMENTED(); } else if (dst->op().op_conf().has_dst_subset_tick_conf()) { @@ -436,17 +428,82 @@ TaskGraph::TaskGraph() { TaskGraph::~TaskGraph() = default; -Maybe TaskGraph::ConnectDstSubsetTickEdges(const std::vector& src_task_nodes, - const std::vector& dst_task_nodes) { - std::function(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId; - JUST(MakeGetterTaskNode4MachineId7ThrdId(dst_task_nodes, &TaskNode4MachineId7ThrdId)); - for (CompTaskNode* src_task_node : src_task_nodes) { - CompTaskNode* dst_task_node = - JUST(TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id())); - TaskEdge* edge = NewEdge(); - Connect(src_task_node, edge, dst_task_node); +TaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) { + TaskEdge* edge = NewEdge(); + edge->AddLbi(lbi); + return edge; +} + +TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector& lbis) { + TaskEdge* edge = NewEdge(); + edge->AddLbis(lbis); + return edge; +} + +TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, + int64_t dst_machine_id, int64_t dst_mem_zone_id) { + int64_t src_mem_zone_id = src_node->MemZoneId121(); + const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id); + if (proxy2node.find(key) != proxy2node.cend()) { + return proxy2node.at(key); + } else { + if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) { + proxy2node[key] = src_node; + return src_node; + } else if (Global::Get()->IsGpuMemZone(dst_mem_zone_id)) { + TaskNode* proxy_on_dst_host = + GetProxyNode(src_node, lbi, dst_machine_id, Global::Get()->CpuMemZoneId()); + CopyHdTaskNode* copy_task = NewNode(); + copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(), + Global::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi); + Connect(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task); + proxy2node[key] = copy_task; + return copy_task; + } else if (Global::Get()->IsCpuMemZone(dst_mem_zone_id)) { + if (src_node->machine_id() == dst_machine_id) { + if (Global::Get()->IsGpuMemZone(src_mem_zone_id)) { + CopyHdTaskNode* copy_task = NewNode(); + copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(), + Global::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi); + Connect(src_node, NewTaskEdgeWithLbi(lbi), copy_task); + proxy2node[key] = copy_task; + return copy_task; + } else { + UNIMPLEMENTED(); + } + } else { + TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(), + Global::Get()->CpuMemZoneId()); + CopyCommNetTaskNode* copy_comm_net_task = NewNode(); + copy_comm_net_task->Init(dst_machine_id, lbi); + Connect(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task); + proxy2node[key] = copy_comm_net_task; + return copy_comm_net_task; 
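// Taken together, the branches above build at most one proxy chain per ProxyKey
// {src_node, lbi, dst_machine_id, dst_mem_zone_id}: D2H on the source machine, CopyCommNet
// across machines, H2D on the destination device, every hop connected through
// NewTaskEdgeWithLbi(lbi). A small illustration of the memoization (task_graph, p1, p2 are
// illustrative names):
//
//   TaskNode* p1 = task_graph->GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
//   TaskNode* p2 = task_graph->GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
//   CHECK_EQ(p1, p2);  // the second call hits proxy2node and reuses the cached chain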
+ } + } else { + UNIMPLEMENTED(); + } } - return Maybe::Ok(); + UNIMPLEMENTED(); + return nullptr; +} + +TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, + const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) { + const int64_t dst_machine_id = + CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); + int64_t dst_mem_zone_id; + const IDMgr* id_mgr = Global::Get(); + if (dst_parallel_desc.device_type() == DeviceType::kCPU) { + dst_mem_zone_id = id_mgr->CpuMemZoneId(); + } else if (dst_parallel_desc.device_type() == DeviceType::kGPU) { + const int64_t dst_dev_phy_id = + CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); + dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id); + } else { + UNIMPLEMENTED(); + } + return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id); } void TaskGraph::ConnectCtrlEdges(const std::vector& src_task_nodes, @@ -640,17 +697,7 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const OpNode* src_op_node = op_edge->src_node(); const OpNode* dst_op_node = op_edge->dst_node(); for (const LogicalBlobId& lbi : op_edge->lbis()) { - std::vector in_nodes; - if (op_edge->lbis().size() == 1) { - in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end()); - } else { - for (CompTaskNode* src_node : sorted_src_comp_tasks) { - auto* identity_node = NewNode(); - identity_node->Init(src_node->machine_id(), src_node->thrd_id(), lbi); - Connect(src_node, NewEdge(), identity_node); - in_nodes.push_back(identity_node); - } - } + std::vector in_nodes(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end()); std::vector out_nodes; out_nodes.reserve(sorted_dst_comp_tasks.size()); std::vector> sorted_ctrl_tasks; @@ -668,7 +715,10 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { boxing_logger_->Log(*status, src_op_node->op().op_name(), dst_op_node->op().op_name(), src_parallel_desc, dst_parallel_desc, src_parallel_distribution, dst_parallel_distribution, lbi, blob_desc); - sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks); + CHECK_EQ(out_nodes.size(), sorted_dst_comp_tasks.size()); + FOR_RANGE(size_t, i, 0, out_nodes.size()) { + ConnectWithLbi(out_nodes.at(i), sorted_dst_comp_tasks.at(i), lbi); + } if (!sorted_ctrl_tasks.empty()) { CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) { @@ -687,9 +737,9 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) { CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size()); FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) { - CompTaskNode* src = sorted_src_comp_tasks.at(i); - CompTaskNode* dst = sorted_dst_comp_tasks.at(i); - BuildTaskPath(src, dst, MutBufTask, true); + for (const LogicalBlobId& lbi : op_edge->lbis()) { + BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi); + } } } @@ -698,7 +748,9 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) { CompTaskNode* nearest_src_node = SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); CHECK_NOTNULL(nearest_src_node); - BuildTaskPath(nearest_src_node, dst_node, MutBufTask, true); + for (const LogicalBlobId& lbi : op_edge->lbis()) { + BuildTaskPath(nearest_src_node, dst_node, lbi); + } } } @@ -712,7 +764,7 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) { FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) { const auto& lbi = 
dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i)); if (lbis.find(lbi) != lbis.end()) { - BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), MutBufTask, true); + BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), lbi); } } } @@ -727,30 +779,31 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) { FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) { const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i)); if (lbis.find(lbi) != lbis.end()) { - BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), MutBufTask, true); + BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), lbi); } } } -Maybe TaskGraph::ConnectSrcSubsetTickEdges(const std::vector& src_task_nodes, - const std::vector& dst_task_nodes) { +DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) { std::function(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId; - JUST(MakeGetterTaskNode4MachineId7ThrdId(src_task_nodes, &TaskNode4MachineId7ThrdId)); - for (CompTaskNode* dst_task_node : dst_task_nodes) { - CompTaskNode* src_task_node = - JUST(TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id())); - TaskEdge* edge = NewEdge(); - Connect(src_task_node, edge, dst_task_node); + CHECK_JUST( + MakeGetterTaskNode4MachineId7ThrdId(sorted_src_comp_tasks, &TaskNode4MachineId7ThrdId)); + for (CompTaskNode* dst_task_node : sorted_dst_comp_tasks) { + CompTaskNode* src_task_node = CHECK_JUST( + TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id())); + Connect(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node); } - return Maybe::Ok(); -} - -DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) { - CHECK_JUST(ConnectSrcSubsetTickEdges(sorted_src_comp_tasks, sorted_dst_comp_tasks)); } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect) { - CHECK_JUST(ConnectDstSubsetTickEdges(sorted_src_comp_tasks, sorted_dst_comp_tasks)); + std::function(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId; + CHECK_JUST( + MakeGetterTaskNode4MachineId7ThrdId(sorted_dst_comp_tasks, &TaskNode4MachineId7ThrdId)); + for (CompTaskNode* src_task_node : sorted_src_comp_tasks) { + CompTaskNode* dst_task_node = CHECK_JUST( + TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id())); + Connect(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node); + } } DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) { @@ -758,103 +811,30 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) { FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) { CompTaskNode* src = sorted_src_comp_tasks.at(i); CompTaskNode* dst = sorted_dst_comp_tasks.at(i); - Connect(src, NewEdge(), dst); + for (const LogicalBlobId& lbi : op_edge->lbis()) { BuildTaskPath(src, dst, lbi); } } } -void TaskGraph::BuildTaskPath( - CompTaskNode* src, CompTaskNode* dst, - std::function - MutBufTask, - bool use_buf_task_node) { - CHECK_NE(src, dst); - auto GetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) { - return *MutBufTask(src, machine_id, mem_zone_id); - }; - auto SetBufTask = [&](int64_t machine_id, int32_t mem_zone_id, TaskNode* new_val) { - TaskNode** cur_val = MutBufTask(src, machine_id, mem_zone_id); - if (*cur_val == nullptr) { - *cur_val = new_val; - } else { - CHECK_EQ(*cur_val, new_val); - } - return new_val; - }; - TaskNode* cur_node = src; - while (cur_node->machine_id() != dst->machine_id() - || 
cur_node->MemZoneId121() != dst->MemZoneId121()) { - cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node); - } - if (cur_node != dst) { Connect(cur_node, NewEdge(), dst); } -} - -TaskNode* TaskGraph::BuildTaskStep( - TaskNode* cur_node, TaskNode* dst, - const std::function& GetBufTask, - const std::function& SetBufTask, - bool use_buf_task_node) { - int32_t cpu_mem_zone_id = Global::Get()->CpuMemZoneId(); - int32_t next_mem_zone_id = -1; - TaskNode* next_node = nullptr; - if (cur_node->MemZoneId121() != cpu_mem_zone_id) { - next_mem_zone_id = cpu_mem_zone_id; - if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) { - next_node = AddCopyD2HTaskFrom(cur_node); - Connect(cur_node, NewEdge(), next_node); - } - } else if (cur_node->machine_id() == dst->machine_id()) { - next_mem_zone_id = dst->MemZoneId121(); - if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) { - next_node = TryAddCopyH2DTaskTo(dst); - if (next_node == nullptr) { next_node = dst; } - Connect(cur_node, NewEdge(), next_node); +void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) { + if (src_node == dst_node) { return; } + for (TaskEdge* out_edge : src_node->out_edges()) { + TaskNode* out_node = out_edge->dst_node(); + if (out_node == dst_node) { + out_edge->AddLbi(lbi); + return; } - } else if (cur_node->machine_id() != dst->machine_id()) { - next_mem_zone_id = cpu_mem_zone_id; - if (!use_buf_task_node || !(next_node = GetBufTask(dst->machine_id(), next_mem_zone_id))) { - next_node = AddCopyCommNetTaskBetween(cur_node, dst); - Connect(cur_node, NewEdge(), next_node); - } - } else { - UNIMPLEMENTED(); - } - if (use_buf_task_node && (next_node != dst)) { - SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node); } - return next_node; -} -TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) { - if (IsInterfaceTask(task)) { return nullptr; } - if (IsClassRegistered(task->GetTaskType())) { return nullptr; } - CHECK_EQ(task->device_type(), DeviceType::kGPU); - CopyHdTaskNode* copy_task = NewNode(); - copy_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId()); - return copy_task; + TaskEdge* connected_edge = NewEdge(); + connected_edge->AddLbi(lbi); + Connect(src_node, connected_edge, dst_node); } -TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) { - CHECK_EQ(task->device_type(), DeviceType::kGPU); - CopyHdTaskNode* copy_task = NewNode(); - copy_task->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId()); - return copy_task; -} - -TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) { - CHECK_NE(src->machine_id(), dst->machine_id()); - CopyCommNetTaskNode* copy_comm_net_task = NewNode(); - copy_comm_net_task->Init(dst->machine_id()); - return copy_comm_net_task; -} - -void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) { - if (src->machine_id() == dst->machine_id()) { - Connect(src, NewEdge(), dst); - } else { - TaskNode* copy_comm_net_task = AddCopyCommNetTaskBetween(src, dst); - Connect(src, NewEdge(), copy_comm_net_task); - Connect(copy_comm_net_task, NewEdge(), dst); - } +void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) { + int64_t dst_machine_id = dst_node->machine_id(); + int64_t dst_mem_zone_id = dst_node->MemZoneId121(); + TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id); + ConnectWithLbi(proxy_node, 
dst_node, lbi); } } // namespace oneflow diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index cfd39cffe6..81cc831940 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -30,11 +30,9 @@ namespace oneflow { class SubTskGphBuilderCtx; class HierarchicalSubTskGphBuilder; -#define BLD_SUB_TSK_GPH_MTHD_ARGS() \ - (const OpEdge* op_edge, const std::vector& sorted_src_comp_tasks, \ - const std::vector& sorted_dst_comp_tasks, \ - std::function \ - MutBufTask) +#define BLD_SUB_TSK_GPH_MTHD_ARGS() \ + (const OpEdge* op_edge, const std::vector& sorted_src_comp_tasks, \ + const std::vector& sorted_dst_comp_tasks) class TaskGraph; using BldSubTskGphMthd = void(TaskGraph::*) BLD_SUB_TSK_GPH_MTHD_ARGS(); @@ -53,6 +51,17 @@ class TaskGraph final : public Graph { void EnableInplaceMemSharing(const std::function& IsOpNameDataOrCtrlReachable); + TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id, + int64_t dst_mem_zone_id); + + TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, + const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id); + + TaskEdge* NewTaskEdgeWithLbi(const LogicalBlobId& lbi); + TaskEdge* NewTaskEdgeWithLbis(const std::vector& lbis); + + void ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi); + #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing); @@ -65,26 +74,8 @@ class TaskGraph final : public Graph { DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D); private: - void BuildTaskPath( - CompTaskNode* src, CompTaskNode* dst, - std::function - MutBufTask, - bool use_buf_task_node); - TaskNode* BuildTaskStep( - TaskNode* cur_node, TaskNode* dst, - const std::function& GetBufTask, - const std::function& - SetBufTask, - bool use_buf_task_node); - - TaskNode* TryAddCopyH2DTaskTo(TaskNode*); - TaskNode* AddCopyD2HTaskFrom(TaskNode*); - TaskNode* AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst); - void ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst); - Maybe ConnectSrcSubsetTickEdges(const std::vector& src_task_nodes, - const std::vector& dst_task_nodes); - Maybe ConnectDstSubsetTickEdges(const std::vector& src_task_nodes, - const std::vector& dst_task_nodes); + void BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi); + void ConnectCtrlEdges(const std::vector& src_task_nodes, const std::vector& dst_task_nodes, int64_t ctrl_regst_num); @@ -109,6 +100,30 @@ class TaskGraph final : public Graph { std::unique_ptr hierarchical_sub_tsk_gph_builder_; std::unique_ptr sub_tsk_gph_builder_ctx_; std::unique_ptr boxing_logger_; + + struct ProxyKey { + TaskNode* src_node; + LogicalBlobId lbi; + int64_t dst_machine_id; + int64_t dst_mem_zone_id; + + ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, int64_t arg_machine, int64_t arg_zone) + : src_node(src), lbi(arg_lbi), dst_machine_id(arg_machine), dst_mem_zone_id(arg_zone) {} + + bool operator==(const ProxyKey& other) const { + return src_node == other.src_node && lbi == other.lbi + && dst_machine_id == other.dst_machine_id && dst_mem_zone_id == other.dst_mem_zone_id; + } + + struct Hasher { + inline size_t operator()(const ProxyKey& key) const { + return std::hash{}(key.src_node) ^ std::hash{}(key.lbi) + ^ key.dst_machine_id ^ key.dst_mem_zone_id; + } + }; + }; + + HashMap proxy2node; }; } // namespace oneflow diff --git 
diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp
index 6a70e2b00b..852d15f089 100644
--- a/oneflow/core/graph/task_node.cpp
+++ b/oneflow/core/graph/task_node.cpp
@@ -354,7 +354,13 @@ void TaskNode::TryLockConsumedRegst(const std::string& name) {
 }

 void TaskNode::LockRegsts() {
-  for (auto& pair : produced_regsts_) { pair.second->Lock(); }
+  for (auto& pair : produced_regsts_) {
+    std::shared_ptr regst = pair.second;
+    regst->Lock();
+
+    // NOTE(chengcheng): CHECK 1 regst 1 blob.
+    if (regst->regst_desc_type().has_data_regst_desc()) { CHECK_LE(regst->NumOfLbi(), 1); }
+  }
 }

 void TaskNode::UpdateTaskId() {
@@ -397,6 +403,30 @@ void TaskEdge::AddRegst(const std::string& name_in_producer,
   CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second);
 }

+void TaskEdge::CheckRegstLbiValid() const {
+  HashMap> lbi2data_regst;
+  for (auto& pair : name_in_producer2regst_) {
+    std::shared_ptr regst = pair.second;
+    if (regst->regst_desc_type().has_data_regst_desc()) {
+      // NOTE(chengcheng): regst_desc_type is Set, BUT regst_desc_type.data_regst_desc is UNSET!
+      // So you can ONLY use NumOfLbi and ForEachLbi interface.
+      CHECK_EQ(regst->NumOfLbi(), 1);
+      regst->ForEachLbi(
+          [&](const LogicalBlobId& lbi) { CHECK(lbi2data_regst.emplace(lbi, regst).second); });
+    }
+  }
+
+  CHECK_EQ(lbi2data_regst.size(), lbis_.size())
+      << " \n\n TaskEdge lbi and regst NOT match."
+      << " TaskEdge: edge_id = " << edge_id() << " From: [" << src_node()->VisualStr() << "] To: ["
+      << dst_node()->VisualStr() << "]\n";
+  for (auto& lbi : lbis_) {
+    CHECK(lbi2data_regst.find(lbi) != lbi2data_regst.end())
+        << " \n\n Cannot find lbi: " << lbi.DebugString() << " in TaskEdge From: ["
+        << src_node()->VisualStr() << "] To: [" << dst_node()->VisualStr() << "]\n\n";
+  }
+}
+
 RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto,
                                                   const std::string& regst_desc_name) {
   auto* produced_regst_desc = task_proto->mutable_produced_regst_desc();
diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h
index 80aadc3ba6..e3ed660003 100644
--- a/oneflow/core/graph/task_node.h
+++ b/oneflow/core/graph/task_node.h
@@ -169,10 +169,16 @@ class TaskEdge final : public Edge {
   std::shared_ptr GetRegst(const std::string& name_in_producer) const;
   std::shared_ptr GetSoleRegst() const;
   std::vector> GetRegsts() const;
+  const HashSet& GetLbis() const { return lbis_; }

   void AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst);
+  void AddLbi(const LogicalBlobId& lbi) { lbis_.insert(lbi); }
+  void AddLbis(const std::vector& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }
+
+  void CheckRegstLbiValid() const;

  private:
+  HashSet lbis_;
   HashMap> name_in_producer2regst_;
 };
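Reviewer sketch (illustrative only, not part of the patch): CheckRegstLbiValid enforces two rules per edge, namely that every data regst on the edge carries exactly one lbi, and that the edge's lbi set matches those regsts one-to-one. The self-contained check below restates that rule with plain std containers standing in for RegstDesc and LogicalBlobId; FakeRegst and EdgeIsValid are hypothetical names.

#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

struct FakeRegst {
  std::vector<std::string> lbis;  // a data regst is valid only if this has size 1
};

bool EdgeIsValid(const std::map<std::string, FakeRegst>& name2regst,
                 const std::set<std::string>& edge_lbis) {
  std::map<std::string, const FakeRegst*> lbi2regst;
  for (const auto& pair : name2regst) {
    if (pair.second.lbis.size() != 1) { return false; }  // the 1 regst 1 blob invariant
    if (!lbi2regst.emplace(pair.second.lbis.front(), &pair.second).second) { return false; }
  }
  if (lbi2regst.size() != edge_lbis.size()) { return false; }  // counts must match
  for (const std::string& lbi : edge_lbis) {
    if (lbi2regst.find(lbi) == lbi2regst.end()) { return false; }  // every edge lbi has a regst
  }
  return true;
}

int main() {
  std::map<std::string, FakeRegst> regsts{{"__out_0", {{"op/out_0"}}}, {"__out_1", {{"op/out_1"}}}};
  assert(EdgeIsValid(regsts, {"op/out_0", "op/out_1"}));
  assert(!EdgeIsValid(regsts, {"op/out_0"}));  // an lbi is missing from the edge
  return 0;
}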
diff --git a/oneflow/core/graph/transport_task_node.h b/oneflow/core/graph/transport_task_node.h
new file mode 100644
index 0000000000..75dbc6b722
--- /dev/null
+++ b/oneflow/core/graph/transport_task_node.h
@@ -0,0 +1,39 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
+#define ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
+
+#include "oneflow/core/graph/task_node.h"
+#include "oneflow/core/register/logical_blob_id.pb.h"
+
+namespace oneflow {
+
+class TransportTaskNode : public TaskNode {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TransportTaskNode);
+  TransportTaskNode() = default;
+  virtual ~TransportTaskNode() = default;
+
+  void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; }
+  LogicalBlobId lbi() const { return lbi_; }
+
+ private:
+  LogicalBlobId lbi_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
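Reviewer sketch (illustrative only, not part of the patch): TransportTaskNode only adds a stored LogicalBlobId, so transport-style subclasses can later name the single blob they move, keeping the 1 regst 1 blob invariant. The toy classes below show the intended set_lbi()/lbi() round trip; ToyTransportNode and ToyCopyNode are hypothetical stand-ins, not OneFlow types.

#include <iostream>
#include <string>

class ToyTransportNode {
 public:
  virtual ~ToyTransportNode() = default;
  void set_lbi(const std::string& lbi) { lbi_ = lbi; }
  const std::string& lbi() const { return lbi_; }

 private:
  std::string lbi_;  // stands in for the LogicalBlobId member
};

class ToyCopyNode : public ToyTransportNode {
 public:
  // A transport node produces exactly one data regst, and that regst holds
  // exactly the one blob named by lbi().
  void ProduceOutRegst() const {
    std::cout << "produce regst \"copy_out\" with sole lbi: " << lbi() << std::endl;
  }
};

int main() {
  ToyCopyNode copy_node;
  copy_node.set_lbi("op_a/out_0");  // the graph builder records which blob is transported
  copy_node.ProduceOutRegst();
  return 0;
}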
diff --git a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp
index f768dad232..8d3ef0f2d1 100644
--- a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp
+++ b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp
@@ -32,27 +32,12 @@ size_t RegstNum4OpSameOutputBlob(OperatorConf::OpTypeCase op_type_case) {
   }
 }

-std::string GetOutRegstNameByObn(const std::string& obn) {
-  return "NormalForwardCompTaskNodeOutRegstName_" + obn;
-}
+std::string GetOutRegstNameByObn(const std::string& obn) { return "__" + obn; }

 }  // namespace

 bool NormalForwardCompTaskNode::HasBackwardCompTaskNode() { return false; }

-bool NormalForwardCompTaskNode::CanProduceSeperatedRegstsForEachOutBlob() const {
-  return op()->output_bns().size() > 1 && IsAllOutNodeNormalForward();
-}
-
-bool NormalForwardCompTaskNode::IsAllOutNodeNormalForward() const {
-  bool ret = true;
-  ForEachNodeOnOutDataEdge([&](TaskNode* node) {
-    auto* fw_node = dynamic_cast(node);
-    if (fw_node == nullptr) { ret = false; }
-  });
-  return ret;
-}
-
 void NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum(const std::string& name,
                                                                  size_t mem_block_num) {
   if (mem_block_num != -1) {
@@ -76,33 +61,21 @@ void NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
   }
   // when output blob num > 1 and task node on out edge is all NormalForwardCompTaskNode ,
   // create multi out regst by output blob name in op
-  if (CanProduceSeperatedRegstsForEachOutBlob()) {
-    HashMap lbi2out_regst_name;
-    for (const std::string& obn : sole_op->output_bns()) {
-      const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn);
-      std::string out_regst_name = GetOutRegstNameByObn(obn);
-      lbi2out_regst_name.insert({lbi, out_regst_name});
-      ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num);
-    }
-    ForEachOutDataEdge([&](TaskEdge* edge) {
-      TaskNode* node = edge->dst_node();
-      auto* dst_node = dynamic_cast(node);
-      CHECK(dst_node != nullptr) << "1regst1blob ONLY support normal fw comp task node 121";
-      std::shared_ptr dst_op = dst_node->op();
-      bool is_found = false;
-      for (const std::string& ibn : dst_op->input_bns()) {
-        const LogicalBlobId& dst_in_lbi = dst_op->BnInOp2Lbi(ibn);
-        if (lbi2out_regst_name.find(dst_in_lbi) != lbi2out_regst_name.end()) {
-          is_found = true;
-          BindEdgeWithProducedRegst(edge, lbi2out_regst_name.at(dst_in_lbi));
-        }
-      }
-      CHECK(is_found) << "Cannot find comsumed blob in dst op: " << dst_op->op_name();
-    });
-  } else {
-    ProduceOutRegstByNameAndBlockNum("out", mem_block_num);
-    ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); });
+
+  HashMap lbi2out_regst_name;
+  for (const std::string& obn : sole_op->output_bns()) {
+    const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn);
+    std::string out_regst_name = GetOutRegstNameByObn(obn);
+    lbi2out_regst_name.insert({lbi, out_regst_name});
+    ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num);
   }
+  ForEachOutDataEdge([&](TaskEdge* edge) {
+    for (const LogicalBlobId& lbi : edge->GetLbis()) {
+      auto it = lbi2out_regst_name.find(lbi);
+      CHECK(it != lbi2out_regst_name.end());
+      BindEdgeWithProducedRegst(edge, it->second);
+    }
+  });

   ProduceRegst("tmp", true);
 }
@@ -127,45 +100,21 @@ void NormalForwardCompTaskNode::BuildExecGphAndRegst() {
 }

 void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
-  HashMap> lbi2producer;
   ExecNode* cur_node = mut_exec_gph().NewNode();
   cur_node->mut_op() = op();
-  for (const std::string& obn : op()->output_bns()) {
-    const LogicalBlobId& lbi = op()->BnInOp2Lbi(obn);
-    CHECK(lbi2producer.insert({lbi, {cur_node, obn}}).second);
-  }
   const std::list>& in_regsts = GetConsumedRegst("in");
   for (const std::string& ibn : cur_node->op()->input_bns()) {
-    const LogicalBlobId& lbi = cur_node->op()->BnInOp2Lbi(ibn);
-    auto producer_it = lbi2producer.find(lbi);
-    if (producer_it != lbi2producer.end()) {
-      ExecEdge* edge = mut_exec_gph().NewEdge();
-      edge->set_lbi(lbi);
-      edge->mut_src_bn() = producer_it->second.second;
-      edge->mut_dst_bn() = ibn;
-      Connect(producer_it->second.first, edge, cur_node);
-    } else {
-      cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
-    }
+    cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
   }
 }

 void NormalForwardCompTaskNode::BuildOutRegst() {
-  if (CanProduceSeperatedRegstsForEachOutBlob()) {
-    ExecNode* exec_node = mut_exec_gph().SoleNode();
-    for (const std::string& obn : exec_node->op()->output_bns()) {
-      std::string out_regst_name = GetOutRegstNameByObn(obn);
-      std::shared_ptr out_regst = GetProducedRegst(out_regst_name);
-      out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
-      exec_node->BindBnWithRegst(obn, out_regst);
-    }
-  } else {
-    std::shared_ptr out_regst = GetProducedRegst("out");
-    ExecNode* exec_node = mut_exec_gph().SoleNode();
-    for (const std::string& obn : exec_node->op()->output_bns()) {
-      out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
-      exec_node->BindBnWithRegst(obn, out_regst);
-    }
+  ExecNode* exec_node = mut_exec_gph().SoleNode();
+  for (const std::string& obn : exec_node->op()->output_bns()) {
+    std::string out_regst_name = GetOutRegstNameByObn(obn);
+    std::shared_ptr out_regst = GetProducedRegst(out_regst_name);
+    out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
+    exec_node->BindBnWithRegst(obn, out_regst);
   }
 }

diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index 84ea357d32..45a6277736 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -87,6 +87,8 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
   }
   task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);

+  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
+
   task_gph->ForEachNode([&](TaskNode* task_node) {
     if (task_node->IsMeaningLess()) { return; }
     task_node->ToProto(plan->mutable_task()->Add());
--
GitLab
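Reviewer sketch appended after the patch trailer (illustrative only, not part of the patch): the rewritten NormalForwardCompTaskNode produces one out regst per output blob name ("__" + obn) and binds each out edge by looking up the lbis that edge carries. The small program below replays that bookkeeping with plain maps; only GetOutRegstNameByObn matches the patch, everything else is a stand-in.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

std::string GetOutRegstNameByObn(const std::string& obn) { return "__" + obn; }

int main() {
  // obn -> lbi, as BnInOp2Lbi would report for a two-output op.
  std::map<std::string, std::string> obn2lbi{{"out_0", "op_a/out_0"}, {"out_1", "op_a/out_1"}};

  // Produce one regst per output blob and remember lbi -> regst name.
  std::map<std::string, std::string> lbi2out_regst_name;
  for (const auto& pair : obn2lbi) {
    lbi2out_regst_name[pair.second] = GetOutRegstNameByObn(pair.first);
  }

  // Each out edge carries the lbis it transports; bind it to the matching regsts.
  std::vector<std::set<std::string>> out_edges{{"op_a/out_0"}, {"op_a/out_1"}};
  for (const auto& edge_lbis : out_edges) {
    for (const std::string& lbi : edge_lbis) {
      auto it = lbi2out_regst_name.find(lbi);
      if (it == lbi2out_regst_name.end()) { std::cerr << "unknown lbi " << lbi << "\n"; return 1; }
      std::cout << "bind edge lbi " << lbi << " -> regst " << it->second << std::endl;
    }
  }
  return 0;
}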