/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/inplace_lbi_graph.h"
#include "oneflow/core/graph/id_serialization.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/boxing_identity_task_node.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
#include "oneflow/core/graph/stream_index_getter_registry_manager.h"
#include "oneflow/core/primitive/include/memcpy.h"

namespace oneflow {

namespace {

bool IsMemcpyPrimitiveSupported(DeviceType device_type, primitive::MemcpyKind kind) {
  auto primitive = primitive::NewPrimitive<primitive::MemcpyFactory>(device_type, kind);
  return primitive.operator bool();
}

bool IsMemcpyHtoDSupported(DeviceType device_type) {
  return IsMemcpyPrimitiveSupported(device_type, primitive::MemcpyKind::kHtoD);
}

bool IsMemcpyDtoHSupported(DeviceType device_type) {
  return IsMemcpyPrimitiveSupported(device_type, primitive::MemcpyKind::kDtoH);
}

bool IsConnectToTickOp(const TaskNode* node) {
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  const Operator* op = comp_task_node->op().get();
  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
  return false;
}

std::string GetOpConfCalculationPassName(const OperatorConf& op_conf) {
  CHECK(op_conf.has_scope_symbol_id());
  int64_t scope_symbol_id = op_conf.scope_symbol_id();
  CHECK(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id))
      << " Error! op : \n " << op_conf.DebugString()
      << " has error scope_symbol_id = " << scope_symbol_id
      << " which cannot be found in Global<symbol::Storage<Scope>>::Get()\n";
  const Scope& scope = Global<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);
  return scope.scope_proto().calculation_pass_name();
}

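// Returns true if `op` was created in the optimizer calculation pass, judged by the
// calculation_pass_name recorded in the op's scope.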
bool IsOptimizerPassOp(const Operator* op) {
  // NOTE(chengcheng): use scope::calculation_pass_name instead of area_id so that optimizer
  //   ops are not merged with fw/bw ops.
  if (!op->op_conf().has_scope_symbol_id()) {
    // NOTE(chengcheng): Some system ops inserted into the OpGraph may not set scope_symbol_id;
    //   they must not be treated as optimizer-subgraph ops.
    return false;
  }
  return GetOpConfCalculationPassName(op->op_conf()) == kOptimizerPass;
}

bool IsSubsetTickOpConf(const OperatorConf& op_conf) {
  return op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf();
}

bool IsTickOpConf(const OperatorConf& conf) {
  return IsClassRegistered<int32_t, IsTickTockOpTypeCase>(conf.op_type_case());
}

bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {
  const OperatorConf& op_conf = op->op_conf();
  if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf()
      || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf()
      || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf()
      || op_conf.has_acc_tick_conf()) {
    return true;
  }
  if (op_conf.has_user_conf()) {
    const std::string& user_type_name = op_conf.user_conf().op_type_name();
    if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack"
        || user_type_name == "unpack" || user_type_name == "identity_buffer") {
      return true;
    }
  }
  // NOTE(chengcheng): optimizer pass ops are excluded ONLY when nccl_use_compute_stream = false
  if (!Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
      && IsOptimizerPassOp(op)) {
    return true;
  }
  return false;
}

bool IsTaskNodeProducedRegstHasMultiRegstNum(const TaskNode* node) {
  for (const auto& pair : node->produced_regsts()) {
    if (pair.second->min_register_num() > 1) { return true; }
  }
  return false;
}

bool CanBeMergedInChain(const TaskNode* node) {
  // Only a NormalForward node on GPU that is not a variable (or another special op) can be merged.
  if (IsTaskNodeProducedRegstHasMultiRegstNum(node)) { return false; }
  const auto* fw_comp_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);
  if (fw_comp_node == nullptr) { return false; }
  if (fw_comp_node->device_type() != DeviceType::kGPU) { return false; }
  const Operator* op = fw_comp_node->op().get();
  if (IsSpecialOpNotConsiderMergeInChain(op)) { return false; }
  return true;
}

std::shared_ptr<const Shape> GetTaskNodeTimeShape(const TaskNode* node) {
  const auto* fw_comp_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);
  CHECK(fw_comp_node != nullptr);
  return CHECK_JUST(fw_comp_node->op()->GetOpTimeShape());
}

void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) {
  CHECK_NE(this_chain_id, -1);
  CHECK_EQ(this_node->chain_id(), -1);
  // BFS over all nodes that can be merged into this chain.
  std::shared_ptr<const Shape> seed_time_shape = GetTaskNodeTimeShape(this_node);
  HashSet<TaskNode*> visited_nodes;
  std::queue<TaskNode*> queued_nodes;
  queued_nodes.push(this_node);
  visited_nodes.insert(this_node);
  while (!queued_nodes.empty()) {
    TaskNode* cur_node = queued_nodes.front();
    queued_nodes.pop();

    CHECK_EQ(cur_node->chain_id(), -1);
    cur_node->set_chain_id(this_chain_id);

    cur_node->ForEachNodeOnInOutDataEdge([&](TaskNode* next_node) {
      if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node)
          && this_node->thrd_id() == next_node->thrd_id()
          && (*GetTaskNodeTimeShape(next_node)) == (*seed_time_shape)) {
        if (next_node->chain_id() == -1) {
          queued_nodes.push(next_node);
          visited_nodes.insert(next_node);
        } else {
          CHECK_EQ(next_node->chain_id(), this_chain_id);
        }
      }
    });
  }
}

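// Builds a getter mapping an op name to its unique TaskNode. Only task nodes whose exec
// graph holds a single op are indexed; the getter returns nullptr if the name is unknown
// or ambiguous (owned by more than one task node).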
std::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(
    const HashSet<TaskNode*>& task_nodes) {
  auto op_name2task_nodes = std::make_shared<HashMap<std::string, HashSet<TaskNode*>>>();
  for (TaskNode* task_node : task_nodes) {
    if (task_node->exec_gph().node_num() == 1) {
      ExecNode* exec_node = task_node->exec_gph().SoleNode();
      CHECK((*op_name2task_nodes)[exec_node->op()->op_name()].emplace(task_node).second);
    }
  }
  return [op_name2task_nodes](const std::string& op_name) -> TaskNode* {
    const auto& iter = op_name2task_nodes->find(op_name);
    if (iter == op_name2task_nodes->end()) { return nullptr; }
    if (iter->second.size() > 1) { return nullptr; }
    return *iter->second.begin();
  };
}

bool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {
  for (const auto& regst_desc : edge->GetRegsts()) {
    if (regst_desc->HasLbi(lbi)) { return true; }
  }
  return false;
}

std::function<bool(const LogicalBlobId&, const std::string&)>
MakePredicatorIsLbiAllConsumersReachable(
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName,
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,
                                                             const TaskNode* dst_node) -> bool {
    if (src_node->chain_id() == dst_node->chain_id()
        && src_node->order_in_graph() <= dst_node->order_in_graph()) {
      return true;
    }
    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);
    if (comp_src_node == nullptr) { return false; }
    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);
    if (comp_dst_node == nullptr) { return false; }
    return IsOpNameDataOrCtrlReachable(comp_src_node->op()->op_name(),
                                       comp_dst_node->op()->op_name());
  };
  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,
                                                      const std::string& op_name) -> bool {
    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());
    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);
    size_t out_edges_size = 0;
    size_t reachable_out_edges_size = 0;
    for (TaskEdge* out_edge : src_task_node->out_edges()) {
      if (IsLbiOnTaskEdge(out_edge, lbi)) {
        out_edges_size += 1;
        reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);
      }
    }
    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;
  };
}

bool IsInplaceAllowed(TaskNode* task_node, const std::vector<std::string>& bns,
                      const std::function<TaskNode*(const std::string&)>& TaskNode4SoleOpName) {
  if (task_node->exec_gph().node_num() != 1) { return false; }
  const auto& exec_node = *task_node->exec_gph().SoleNode();
  for (const auto& bn : bns) {
    // The TaskNode for bn is non-null only if it is on the same device as `task_node`.
    if (TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) == nullptr) { return false; }
    const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn);
    if (regst_desc.NumOfLbi() != 1) { return false; }
  }
  const BlobDesc* first_blob = nullptr;
  for (const auto& bn : bns) {
    const BlobDesc* blob_desc = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
    if (first_blob == nullptr) {
      first_blob = blob_desc;
    } else {
      if (!(first_blob->shape().elem_cnt() == blob_desc->shape().elem_cnt()
            && first_blob->data_type() == blob_desc->data_type())) {
        return false;
      }
    }
  }
  return true;
}

std::unique_ptr<BoxingLogger> CreateBoxingLogger() {
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
    return std::unique_ptr<BoxingLogger>(
        new CsvBoxingLogger(StrCat("boxing/log/", GlobalJobDesc().job_id()) + ".csv"));
  } else {
    return std::unique_ptr<BoxingLogger>(new NullBoxingLogger());
  }
}

Maybe<void> MakeGetterTaskNode4MachineId7ThrdId(
    const std::vector<CompTaskNode*>& task_nodes,
    std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)>* Getter) {
  // ticks are shared within a machine/process
  auto machine_id2task_node = std::make_shared<HashMap<int64_t, CompTaskNode*>>();
  for (auto* task_node : task_nodes) {
    machine_id2task_node->emplace(task_node->machine_id(), task_node);
  }
  *Getter = [machine_id2task_node](int64_t mchn_id, int64_t thrd_id) -> Maybe<CompTaskNode*> {
    const auto& iter = machine_id2task_node->find(mchn_id);
    CHECK_OR_RETURN(iter != machine_id2task_node->end());
    return iter->second;
  };
  return Maybe<void>::Ok();
}

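// Creates one CompTaskNode per parallel rank of `op_node`, ordered by (machine_id, device_id).
// The thread id is a serialized StreamId whose stream index comes from the op conf's
// stream_index_hint when present, otherwise from the registry keyed by device and task type.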
void GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* sorted_comp_tasks) {
  int64_t parallel_idx = 0;
  const ParallelDesc& parallel_desc = op_node->parallel_desc();
  int64_t parallel_num = parallel_desc.parallel_num();
  for (int64_t machine_id : parallel_desc.sorted_machine_ids()) {
    for (int64_t dev_phy_id : parallel_desc.sorted_dev_phy_ids(machine_id)) {
      CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node);
      comp_task_node->set_machine_id(machine_id);
      comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_idx++);
      comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num);

      DeviceId::device_index_t device_index =
          parallel_desc.device_type() == DeviceType::kCPU
              ? DeviceId::kCPUDeviceIndex
              : static_cast<DeviceId::device_index_t>(dev_phy_id);
      DeviceId device_id{static_cast<DeviceId::rank_t>(machine_id), parallel_desc.device_type(),
                         device_index};
      StreamId::stream_index_t stream_index{};
      if (op_node->op().op_conf().has_stream_index_hint()) {
        int32_t stream_index_hint = op_node->op().op_conf().stream_index_hint();
        LOG(INFO) << "set op: " << op_node->op().op_name() << " to stream: " << stream_index_hint;
        stream_index = static_cast<StreamId::stream_index_t>(stream_index_hint);
      } else {
        stream_index = StreamIndexGetterRegistryManager::Get().StreamIndex4DeviceIdAndTaskType(
            device_id, comp_task_node->GetTaskType());
      }
      comp_task_node->set_thrd_id(SerializeStreamIdToInt64(StreamId{device_id, stream_index}));
      comp_task_node->set_op_node(op_node);
      sorted_comp_tasks->push_back(comp_task_node);
    }
  }
}

bool IsConnectedLbisAllSameNdSbp(const OpEdge* op_edge) {
  const OpNode* src_node = op_edge->src_node();
  const OpNode* dst_node = op_edge->dst_node();
  CHECK_GT(op_edge->lbis().size(), 0);
  HashSet<bool> predicators;
  for (const LogicalBlobId& lbi : op_edge->lbis()) {
    const cfg::NdSbp& src_nd_sbp = src_node->NdSbp4Lbi(lbi);
    const cfg::NdSbp& dst_nd_sbp = dst_node->NdSbp4Lbi(lbi);
    predicators.insert(src_nd_sbp == dst_nd_sbp);
  }
  CHECK_EQ(predicators.size(), 1);
  return *predicators.begin();
}

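// Selects the builder used to connect the tasks of `op_edge`'s src and dst ops. Special
// system ops (WaitAndSendIds/ReentrantLock, the tick family, DistributeConcat/DistributeSplit,
// DecodeH2D) are matched first; otherwise one-to-one is chosen when both sides have identical
// parallelism and NdSbp, and general boxing is the fallback.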
BldSubTskGphMthd GetMthdForBldSubTskGph(const OpEdge* op_edge) {
  const OpNode* src_node = op_edge->src_node();
  const OpNode* dst_node = op_edge->dst_node();
  const ParallelDesc& src_pd = src_node->parallel_desc();
  const ParallelDesc& dst_pd = dst_node->parallel_desc();
  const OperatorConf& src_op_conf = src_node->op().op_conf();
  const OperatorConf& dst_op_conf = dst_node->op().op_conf();

  // WaitAndSendIds -> ReentrantLock
  if (src_op_conf.has_wait_and_send_ids_conf() && dst_op_conf.has_reentrant_lock_conf()) {
    CHECK_EQ(src_pd.parallel_num(), 1);
    CHECK_EQ(dst_pd.parallel_num(), 1);
    return &TaskGraph::BldSubTskGphByBoxing;
  }

  // *Tick -> *Tick
  if (IsTickOpConf(src_op_conf) || IsTickOpConf(dst_op_conf)) {
    if (src_op_conf.has_source_tick_conf()) {
      CHECK(dst_op_conf.has_tick_conf());
      CHECK_EQ(src_pd.parallel_num(), 1);
      CHECK_EQ(dst_pd.parallel_num(), 1);
      return &TaskGraph::BldSubTskGphByBoxing;
    } else if (dst_op_conf.has_sink_tick_conf()) {
      CHECK(src_op_conf.has_tick_conf() || src_op_conf.has_sink_tick_conf());
      CHECK_EQ(src_pd.parallel_num(), 1);
      CHECK_EQ(dst_pd.parallel_num(), 1);
      return &TaskGraph::BldSubTskGphByBoxing;
    } else if (IsSubsetTickOpConf(src_op_conf)) {
      return &TaskGraph::BldSubTskGphBySrcSubsetConnect;
    } else if (IsSubsetTickOpConf(dst_op_conf)) {
      return &TaskGraph::BldSubTskGphByDstSubsetConnect;
    } else if (IsTickOpConf(src_op_conf) && IsTickOpConf(dst_op_conf)) {
      if (src_pd.parallel_num() == dst_pd.parallel_num()) {
        return &TaskGraph::BldSubTskGphByOneToOne;
      } else {
        CHECK_EQ(src_pd.parallel_num(), 1);
        return &TaskGraph::BldSubTskGphByBroadcastToBroadcast;
      }
    }
  }

  std::shared_ptr<CompTaskNode> src_comp_task(NewCompTaskNode4OpNode(src_node));
  std::shared_ptr<CompTaskNode> dst_comp_task(NewCompTaskNode4OpNode(dst_node));
  // NOTE(chengcheng): MUST use TaskType instead of OpTypeCase because multiple ops may
  //   correspond to the SAME TaskType, e.g.:
  //   DistributeConcatOpConf and DistributeAddOpConf -> TaskType::kDistributeConcat
  //   DistributeSplitOpConf and DistributeCloneOpConf -> TaskType::kDistributeSplit
  // * -> DistributeConcat
  if (dst_comp_task->GetTaskType() == TaskType::kDistributeConcat) {
    return &TaskGraph::BldSubTskGphByPartialInLbiConnect;
  }
  // DistributeSplit -> *
  if (src_comp_task->GetTaskType() == TaskType::kDistributeSplit) {
    return &TaskGraph::BldSubTskGphByPartialOutLbiConnect;
  }
  // NormalForward -> DecodeH2D
  if (src_comp_task->GetTaskType() == TaskType::kNormalForward
      && dst_comp_task->GetTaskType() == TaskType::kDecodeH2D) {
    return &TaskGraph::BldSubTskGphNormalForwardToDecodeH2D;
  }
  if (src_pd.parallel_num() == 1 && dst_pd.parallel_num() == 1) {
    return &TaskGraph::BldSubTskGphByOneToOne;
  }
  // one to one
  if (src_pd.parallel_num() == dst_pd.parallel_num() && *src_pd.hierarchy() == *dst_pd.hierarchy()
      && IsConnectedLbisAllSameNdSbp(op_edge)) {
    return &TaskGraph::BldSubTskGphByOneToOne;
  }
  return &TaskGraph::BldSubTskGphByBoxing;
}

void ForEachOpGraphNecessaryCtrlEdge(
    const OpGraph* op_graph, const std::function<void(const OpNode*, const OpNode*)>& Handler) {
  auto IsOpGraphDataReachable = op_graph->MakePredicatorIsReachable();
  op_graph->ForEachNode([&](OpNode* dst) {
    for (const auto& ctrl_in_op_name : dst->op().op_conf().ctrl_in_op_name()) {
      const OpNode* src = op_graph->OpNode4OpName(ctrl_in_op_name);
      CHECK(!IsOpGraphDataReachable(dst, src));
      if (!IsOpGraphDataReachable(src, dst)) {
        CHECK_EQ(dst->parallel_desc().parallel_num(), src->parallel_desc().parallel_num());
        const Shape* src_time_shape = CHECK_JUST(src->op().GetOpTimeShape()).get();
        const Shape* dst_time_shape = CHECK_JUST(dst->op().GetInputBlobFastestTimeShape()).get();
        if (dst_time_shape == nullptr) {
          dst_time_shape = CHECK_JUST(dst->op().GetOpTimeShape()).get();
        }
        CHECK_EQ(src_time_shape->elem_cnt(), dst_time_shape->elem_cnt());
        Handler(src, dst);
      }
    }
  });
}

}  // namespace

TaskGraph::TaskGraph() {
  OpGraph* op_graph = Global<OpGraph>::Get();
  sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
  boxing_logger_ = CreateBoxingLogger();
  hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
  HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;

  op_graph->ForEachNode([&](const OpNode* op_node) {
    std::vector<CompTaskNode*>* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]);
    GenSortedCompTaskNodes(op_node, sorted_comp_tasks);
    for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); }
  });

  op_graph->ForEachEdge([&](const OpEdge* op_edge) {
    BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge);
    (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()),
                    op_node2sorted_comp_tasks.at(op_edge->dst_node()));
  });

  ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) {
    const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src);
    const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst);
    if (src->op().op_conf().has_src_subset_tick_conf()) {
      UNIMPLEMENTED();
    } else if (dst->op().op_conf().has_dst_subset_tick_conf()) {
      UNIMPLEMENTED();
    } else {
      ConnectCtrlEdges(src_task_nodes, dst_task_nodes);
    }
  });

  SetOrderInGraphForEachNode();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}

TaskGraph::~TaskGraph() = default;

TaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) {
  TaskEdge* edge = NewEdge();
  edge->AddLbi(lbi);
  return edge;
}

TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis) {
  TaskEdge* edge = NewEdge();
  edge->AddLbis(lbis);
  return edge;
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
                                  const MemZoneId& dst_mem_zone_id) {
  const auto& src_mem_zone_id = src_node->MemZoneId121();
  const ProxyKey key(src_node, lbi, dst_mem_zone_id);
  auto it = proxy2node.find(key);
  if (it != proxy2node.cend()) {
    // hit cache
    return it->second;
  } else {
    if (src_mem_zone_id == dst_mem_zone_id) {
      // in the same memory zone
      proxy2node[key] = src_node;
      return src_node;
    } else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) {
      if (src_mem_zone_id.node_index() == dst_mem_zone_id.node_index()) {
        // on the same node but not on the same device;
        // src is not in the CPU mem zone, so copy D2H first
        CHECK(IsMemcpyDtoHSupported(src_mem_zone_id.device_id().device_type()));
        CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
        copy_task->Init(CopyHdOpConf::D2H, src_mem_zone_id.device_id(), lbi);
        Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
        proxy2node[key] = copy_task;
        return copy_task;
      } else {
        // not on the same node, so a CopyCommNet from src to dst is needed;
        // build the src-side CPU proxy first
        TaskNode* proxy_on_src_host =
            GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.node_index()));
        CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
        copy_comm_net_task->Init(dst_mem_zone_id.node_index(), lbi);
        Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);
        proxy2node[key] = copy_comm_net_task;
        return copy_comm_net_task;
      }
    } else {
      TaskNode* proxy_on_dst_host =
          GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.node_index()));
      CHECK(IsMemcpyHtoDSupported(dst_mem_zone_id.device_id().device_type()));
      CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
      copy_task->Init(CopyHdOpConf::H2D, dst_mem_zone_id.device_id(), lbi);
      Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
      proxy2node[key] = copy_task;
      return copy_task;
    }
  }
  return nullptr;
}

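// Convenience overload: resolves (dst_parallel_desc, dst_parallel_id) to a MemZoneId and
// delegates to the MemZoneId-based GetProxyNode above.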
TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
                                  const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {
  const int64_t dst_machine_id =
      CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
  const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
  DeviceType device_type = dst_parallel_desc.device_type();
  auto device_index = (device_type == DeviceType::kCPU
                           ? DeviceId::kCPUDeviceIndex
                           : static_cast<DeviceId::device_index_t>(dev_id));
  MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
                        device_index};
  return GetProxyNode(src_node, lbi, mem_zone_id);
}

void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
    std::string regst_desc_name;
    src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
    TaskEdge* edge = NewEdge();
    Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
    src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
  }
}

void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() {
  if (!CHECK_JUST(GlobalMultiClientEnv())) { return; }
  HashMap<int64_t, TaskNode*> rank_id2src_tick;
  HashMap<int64_t, TaskNode*> rank_id2dst_tick;
  HashMap<int64_t, HashSet<TaskNode*>> rank_id2input_output_nodes;

  ForEachNode([&](TaskNode* node) {
    if (node->GetTaskType() == TaskType::kSrcSubsetTick) {
      CHECK(rank_id2src_tick.emplace(node->machine_id(), node).second);
    } else if (node->GetTaskType() == TaskType::kDstSubsetTick) {
      CHECK(rank_id2dst_tick.emplace(node->machine_id(), node).second);
    } else if (node->GetTaskType() == TaskType::kNormalForward) {
      auto* forward_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
      CHECK(forward_node);
      if (forward_node->op()->op_conf().has_input_conf()
          || forward_node->op()->op_conf().has_output_conf()) {
        CHECK(rank_id2input_output_nodes[node->machine_id()].insert(node).second);
      }
    }
  });

  auto AddCtrlEdge = [&](TaskNode* src, TaskNode* dst) {
    std::string ctrl_regst_name;
    src->BuildCtrlRegstDesc(dst, &ctrl_regst_name);
    TaskEdge* edge = NewEdge();
    Connect<TaskNode>(src, edge, dst);
    src->BindEdgeWithProducedRegst(edge, ctrl_regst_name);
  };

  for (auto& pair : rank_id2src_tick) {
    int64_t rank_id = pair.first;
    TaskNode* src = pair.second;
    for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(src, io_task); }
  }

  for (auto& pair : rank_id2dst_tick) {
    int64_t rank_id = pair.first;
    TaskNode* dst = pair.second;
    for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(io_task, dst); }
  }
}

void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });
  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
}

void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() {
  MergeChain();
  BuildCtrlRegstDescInSameChain();
}

void TaskGraph::SetOrderInGraphForEachNode() {
  int64_t order_in_graph = 0;
  auto SetOrderInGraph = [&](TaskNode* task_node) {
    task_node->set_order_in_graph(order_in_graph);
    ordered_task_nodes_.emplace_back(task_node);
    ++order_in_graph;
  };
  TopoForEachNode(SetOrderInGraph);
}

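// Assigns a chain id to every task node in topological order. A mergeable node seeds a BFS
// (TraverseConnectedSubGraphMergeInThisChain) that pulls connected nodes on the same thread
// with the same time shape into its chain; every other node gets a fresh chain id of its own.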
void TaskGraph::MergeChain() {
  int64_t chain_id = 0;
  for (auto* this_node : ordered_task_nodes_) {
    // skip if this node has already been assigned to a chain
    if (this_node->chain_id() != -1) { continue; }
    if (CanBeMergedInChain(this_node)) {
      TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id);
    } else {
      this_node->set_chain_id(chain_id);
    }
    ++chain_id;
  }
  for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); }
}

void TaskGraph::BuildCtrlRegstDescInSameChain() {
  HashMap<int64_t, TaskNode*> chain_id2node;
  for (auto* node : ordered_task_nodes_) {
    if (IsConnectToTickOp(node)) { continue; }
    int64_t chain_id = node->chain_id();
    auto iter = chain_id2node.find(chain_id);
    if (iter == chain_id2node.end()) {
      CHECK(chain_id2node.emplace(chain_id, node).second);
    } else {
      TaskNode* src_node = iter->second;
      TaskNode* dst_node = node;
      std::string ctrl_regst_name;
      bool build_ctrl_edge = src_node->BuildCtrlRegstDescIfNeed(dst_node, &ctrl_regst_name);
      if (build_ctrl_edge) {
        CHECK(!ctrl_regst_name.empty());
        TaskEdge* edge = NewEdge();
        Connect<TaskNode>(src_node, edge, dst_node);
        src_node->BindEdgeWithProducedRegst(edge, ctrl_regst_name);
      }
      iter->second = dst_node;
    }
  }
}

void TaskGraph::GetInplaceOpBlobArgList(
    InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,
    const std::function<TaskNode*(const std::string&)>& TaskNode4OpName) const {
  auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn,
                                      const std::string& obn, const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };
  auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn,
                                    const std::string& obn, const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };

  for (TaskNode* task_node : dev_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    const auto& op = *task_node->exec_gph().SoleNode()->op();
    for (const std::string& ibn : op.input_bns()) {
      if (op.InputBlobModifier4Ibn(ibn).is_mutable()) {
        CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));
        *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);
      }
    }
    for (const auto& pair : task_node->exec_gph().SoleNode()->mut_inplace_obn2ibn()) {
      AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
    }
    for (const auto& pair : task_node->exec_gph().SoleNode()->con_inplace_obn2ibn()) {
      AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
    }
  }
}

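// Narrows the inplace candidates of `dev_nodes` to a safe subset: candidates come from
// GetInplaceOpBlobArgList, then InplaceLbiGraph::ComputeSafeInplaceObns keeps only pairs
// whose input blob has all consumers reachable (by data or ctrl edges), which rules out
// inplace reuse racing with an unordered consumer of the same regst.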
void TaskGraph::GetSafeInplaceOpBlobArgList(
    InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  InplaceObasInfo obas_info;
  GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
  auto IsLbiAllConsumersReachable =
      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);
  InplaceLbiGraph origin_graph(obas_info, Op4OpName);
  InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName);
  origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable);
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
    origin_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot"));
    safe_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot"));
  }
}

void TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
                                        const HashSet<TaskNode*>& dev_nodes) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
  InplaceLbiGraph inplace_gph(obas_info, Op4OpName);
  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*>& inplace_nodes) {
    for (const auto* inplace_node : inplace_nodes) {
      if (inplace_node->in_edges().empty()) { continue; }
      const auto* inplace_edge = inplace_node->SoleInEdge();
      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();
      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());
      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());
      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());
    }
  });
}

void TaskGraph::ForEachGpuDeviceNodes(
    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {
  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;
  ForEachNode([&](TaskNode* task_node) {
    if (task_node->device_type() != DeviceType::kGPU) { return; }
    int64_t dev_phy_id = task_node->stream_id().device_id().device_index();
    global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node);
  });
  for (const auto& pair : global_dev_phy_id2nodes) { Handler(pair.second); }
}

void TaskGraph::EnableInplaceMemSharing(
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
    InplaceObasInfo safe_inplace_obas_info;
    GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);
    SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);
  });
}

#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()

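// Boxing builder: for each LBI on the edge, delegates to the hierarchical sub-task-graph
// builder, logs the boxing decision, connects the produced out nodes one-to-one to the dst
// tasks, and wires up any ctrl tasks the builder returned.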
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
  const OpNode* src_op_node = op_edge->src_node();
  const OpNode* dst_op_node = op_edge->dst_node();
  for (const LogicalBlobId& lbi : op_edge->lbis()) {
    std::vector<TaskNode*> in_nodes(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
    std::vector<TaskNode*> out_nodes;
    out_nodes.reserve(sorted_dst_comp_tasks.size());
    std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;
    const cfg::NdSbp& src_nd_sbp = src_op_node->NdSbp4Lbi(lbi);
    const cfg::NdSbp& dst_nd_sbp = dst_op_node->NdSbp4Lbi(lbi);
    const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc();
    const ParallelDesc& dst_parallel_desc = dst_op_node->parallel_desc();
    const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi);
    auto status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
        sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
        src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp,
        *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())));
    boxing_logger_->Log(*status, src_op_node->op().op_name(), dst_op_node->op().op_name(),
                        src_parallel_desc, dst_parallel_desc, src_nd_sbp, dst_nd_sbp, lbi,
                        blob_desc);
    CHECK_EQ(out_nodes.size(), sorted_dst_comp_tasks.size());
    FOR_RANGE(size_t, i, 0, out_nodes.size()) {
      ConnectWithLbi(out_nodes.at(i), sorted_dst_comp_tasks.at(i), lbi);
    }
    if (!sorted_ctrl_tasks.empty()) {
      CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());
      FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) {
        for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) {
          std::string regst_desc_name;
          ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i), &regst_desc_name);
          TaskEdge* edge = NewEdge();
          Connect<TaskNode>(ctrl_node, edge, sorted_dst_comp_tasks.at(i));
          ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name);
        }
      }
    }
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
    for (const LogicalBlobId& lbi : op_edge->lbis()) {
      BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi);
    }
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
  for (CompTaskNode* dst_node : sorted_dst_comp_tasks) {
    CompTaskNode* nearest_src_node =
        SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node);
    CHECK_NOTNULL(nearest_src_node);
    for (const LogicalBlobId& lbi : op_edge->lbis()) {
      BuildTaskPath(nearest_src_node, dst_node, lbi);
    }
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
  const Operator& src_op = op_edge->src_node()->op();
  const Operator& dst_op = op_edge->dst_node()->op();
  HashSet<LogicalBlobId> lbis;
  for (const auto& obn : src_op.output_bns()) { lbis.insert(src_op.BnInOp2Lbi(obn)); }
  CHECK_EQ(sorted_src_comp_tasks.size(), 1);
  CHECK_EQ(dst_op.input_bns().size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) {
    const auto& lbi = dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i));
    if (lbis.find(lbi) != lbis.end()) {
      BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), lbi);
    }
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
  const Operator& src_op = op_edge->src_node()->op();
  const Operator& dst_op = op_edge->dst_node()->op();
  HashSet<LogicalBlobId> lbis;
  for (const auto& ibn : dst_op.input_bns()) { lbis.insert(dst_op.BnInOp2Lbi(ibn)); }
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CHECK_EQ(src_op.output_bns().size(), sorted_src_comp_tasks.size());
  FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) {
    const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i));
    if (lbis.find(lbi) != lbis.end()) {
      BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), lbi);
    }
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) {
  std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
  CHECK_JUST(
      MakeGetterTaskNode4MachineId7ThrdId(sorted_src_comp_tasks, &TaskNode4MachineId7ThrdId));
  for (CompTaskNode* dst_task_node : sorted_dst_comp_tasks) {
    CompTaskNode* src_task_node = CHECK_JUST(
        TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id()));
    Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);
  }
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect) {
  std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
  CHECK_JUST(
      MakeGetterTaskNode4MachineId7ThrdId(sorted_dst_comp_tasks, &TaskNode4MachineId7ThrdId));
  for (CompTaskNode* src_task_node : sorted_src_comp_tasks) {
    CompTaskNode* dst_task_node = CHECK_JUST(
        TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id()));
    Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);
  }
}

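// DecodeH2D reads its producer's output directly, so src and dst tasks are connected
// pairwise by parallel index without inserting proxy copy nodes.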
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
    CompTaskNode* src = sorted_src_comp_tasks.at(i);
    CompTaskNode* dst = sorted_dst_comp_tasks.at(i);
    for (const LogicalBlobId& lbi : op_edge->lbis()) { ConnectWithLbi(src, dst, lbi); }
  }
}

void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
  if (src_node == dst_node) { return; }
  for (TaskEdge* out_edge : src_node->out_edges()) {
    TaskNode* out_node = out_edge->dst_node();
    if (out_node == dst_node) {
      out_edge->AddLbi(lbi);
      return;
    }
  }
  TaskEdge* connected_edge = NewEdge();
  connected_edge->AddLbi(lbi);
  Connect<TaskNode>(src_node, connected_edge, dst_node);
}

void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
  TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_node->MemZoneId121());
  ConnectWithLbi(proxy_node, dst_node, lbi);
}

}  // namespace oneflow