From 0252bca822ecf72fcc606fb859d96a648024b5b2 Mon Sep 17 00:00:00 2001 From: Jinhui Yuan Date: Wed, 29 Aug 2018 18:55:27 +0800 Subject: [PATCH] sketch of merge reduce project (#1159) * sketch of merge reduce project * add reduce_concat, reduce_split in logical graph (#1160) * add reduce_concat, reduce_split in logical graph * init ReduceTaskNodes in CollectReduceTaskNodes * add CompTaskNode for ReduceConcat & ReduceSplit * set ReduceConcat/Split color index * copy blob desc from ReduceConcat in to ReduceSplit out * refine CollectReduceTaskNodes * SetMemSharing for ReduceConcat, ReduceSplit regst * complete ReduceConcat & ReduceSplit op * fill ReduceConcat & ReduceSplit kernel * simplify ReduceConcatCompActor * make ReduceScatter & ReduceSplit as input-wise actor * reduce_scatter & reduce_split use is_inplace * use ByteSizeOfBlobBody for reduce related packed blob * Fix dev merge reduce (#1168) * check concat and split occur simultaneously * fix ReduceScatter & ReduceSplit as Inputwise actor * ReduceConcat & ReduceSplit works * fix single gpu issue * Refactor reduce (#1170) * backup, not complete yet * remove reduce_id * rm useless comment * add reduce_graph (#1169) * add reduce_graph * fix iter * add IsLogicalNodeMergeable and fix bug * remove needless constructor calls * node VisualStr may conflict, using node_id_str instead * reduce group works (#1171) * refine * sort nodes in topo (#1172) * add reduce_group_size in job_conf, fix 121 config of ReduceSplit and MdUpdt * resolve code review issues (variable names) * refine variable names * Dev merge reduce rename reduce group (#1174) * ReduceGraph=>ChainLogicalGraph * rename Group=>Chain * reformat * use pointer instead of reference for mutable argument * format change * worker node only pull sub_plan (#1176) * log compile time * use c++11 member initialization syntax * FixPackedBlobDescOfProducedRegst for ReduceSplit * Dev merge reduce refine chain logical graph (#1177) * remove IsMerageable * split TryMergeOneChain and rename to TryMergeTwoChains * reformat * resolve review issues Former-commit-id: 3aa79c70e5a788c0baa94b26288011680f162d2d --- oneflow/core/actor/naive_actor.cpp | 2 - .../actor/reduce_concat_compute_actor.cpp | 17 ++ .../core/actor/reduce_concat_compute_actor.h | 23 +++ .../actor/reduce_scatter_compute_actor.cpp | 16 ++ .../core/actor/reduce_scatter_compute_actor.h | 23 +++ .../core/actor/reduce_split_compute_actor.cpp | 16 ++ .../core/actor/reduce_split_compute_actor.h | 23 +++ oneflow/core/comm_network/comm_network.cpp | 33 +--- oneflow/core/graph/chain_logical_graph.cpp | 186 ++++++++++++++++++ oneflow/core/graph/chain_logical_graph.h | 60 ++++++ oneflow/core/graph/graph.h | 8 +- oneflow/core/graph/logical_graph.cpp | 84 +++++++- oneflow/core/graph/logical_graph.h | 11 +- oneflow/core/graph/logical_node.cpp | 22 +++ oneflow/core/graph/logical_node.h | 2 + .../normal_backward_compute_task_node.cpp | 3 +- .../normal_model_update_compute_task_node.cpp | 5 +- .../graph/reduce_concat_compute_task_node.cpp | 47 +++++ .../graph/reduce_concat_compute_task_node.h | 26 +++ .../graph/reduce_split_compute_task_node.cpp | 89 +++++++++ .../graph/reduce_split_compute_task_node.h | 28 +++ oneflow/core/graph/task_graph.cpp | 108 +++++++--- oneflow/core/graph/task_graph.h | 3 + oneflow/core/graph/task_node.cpp | 20 +- oneflow/core/graph/task_node.h | 5 +- oneflow/core/job/compiler.cpp | 42 ++++ oneflow/core/job/compiler.h | 1 + oneflow/core/job/job_conf.proto | 1 + oneflow/core/job/job_desc.h | 1 + oneflow/core/job/oneflow.cpp | 27 ++- oneflow/core/job/plan.proto | 9 + oneflow/core/job/task.proto | 8 +- oneflow/core/kernel/kernel.proto | 10 + oneflow/core/kernel/reduce_concat_kernel.cpp | 22 +++ oneflow/core/kernel/reduce_concat_kernel.h | 22 +++ oneflow/core/kernel/reduce_scatter_kernel.cpp | 6 +- oneflow/core/kernel/reduce_split_kernel.cpp | 22 +++ oneflow/core/kernel/reduce_split_kernel.h | 22 +++ oneflow/core/operator/op_conf.proto | 16 +- oneflow/core/operator/reduce_concat_op.cpp | 54 +++++ oneflow/core/operator/reduce_concat_op.h | 29 +++ oneflow/core/operator/reduce_split_op.cpp | 31 +++ oneflow/core/operator/reduce_split_op.h | 29 +++ 43 files changed, 1122 insertions(+), 90 deletions(-) create mode 100644 oneflow/core/actor/reduce_concat_compute_actor.cpp create mode 100644 oneflow/core/actor/reduce_concat_compute_actor.h create mode 100644 oneflow/core/actor/reduce_scatter_compute_actor.cpp create mode 100644 oneflow/core/actor/reduce_scatter_compute_actor.h create mode 100644 oneflow/core/actor/reduce_split_compute_actor.cpp create mode 100644 oneflow/core/actor/reduce_split_compute_actor.h create mode 100644 oneflow/core/graph/chain_logical_graph.cpp create mode 100644 oneflow/core/graph/chain_logical_graph.h create mode 100644 oneflow/core/graph/reduce_concat_compute_task_node.cpp create mode 100644 oneflow/core/graph/reduce_concat_compute_task_node.h create mode 100644 oneflow/core/graph/reduce_split_compute_task_node.cpp create mode 100644 oneflow/core/graph/reduce_split_compute_task_node.h create mode 100644 oneflow/core/kernel/reduce_concat_kernel.cpp create mode 100644 oneflow/core/kernel/reduce_concat_kernel.h create mode 100644 oneflow/core/kernel/reduce_split_kernel.cpp create mode 100644 oneflow/core/kernel/reduce_split_kernel.h create mode 100644 oneflow/core/operator/reduce_concat_op.cpp create mode 100644 oneflow/core/operator/reduce_concat_op.h create mode 100644 oneflow/core/operator/reduce_split_op.cpp create mode 100644 oneflow/core/operator/reduce_split_op.h diff --git a/oneflow/core/actor/naive_actor.cpp b/oneflow/core/actor/naive_actor.cpp index 8d2e218327..148d85c2bb 100644 --- a/oneflow/core/actor/naive_actor.cpp +++ b/oneflow/core/actor/naive_actor.cpp @@ -11,6 +11,4 @@ void NaiveActor::Act() { }); } -REGISTER_ACTOR(TaskType::kReduceScatter, NaiveActor); - } // namespace oneflow diff --git a/oneflow/core/actor/reduce_concat_compute_actor.cpp b/oneflow/core/actor/reduce_concat_compute_actor.cpp new file mode 100644 index 0000000000..8cedf4dcf3 --- /dev/null +++ b/oneflow/core/actor/reduce_concat_compute_actor.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/actor/reduce_concat_compute_actor.h" + +namespace oneflow { + +void ReduceConcatCompActor::VirtualCompActorInit(const TaskProto& proto) { + InputWiseCompActor::Init(proto); +} + +void ReduceConcatCompActor::SetKernelCtxOther(void** other) { + int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id()); + other_val_ = std::make_pair(in_bn_id, EnableInplace()); + *other = static_cast(&other_val_); +} + +REGISTER_ACTOR(TaskType::kReduceConcat, ReduceConcatCompActor); + +} // namespace oneflow diff --git a/oneflow/core/actor/reduce_concat_compute_actor.h b/oneflow/core/actor/reduce_concat_compute_actor.h new file mode 100644 index 0000000000..e482fce7d0 --- /dev/null +++ b/oneflow/core/actor/reduce_concat_compute_actor.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_ +#define ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_ + +#include "oneflow/core/actor/input_wise_compute_actor.h" + +namespace oneflow { + +class ReduceConcatCompActor final : public InputWiseCompActor { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompActor); + ReduceConcatCompActor() = default; + ~ReduceConcatCompActor() = default; + + private: + void VirtualCompActorInit(const TaskProto& proto) override; + void SetKernelCtxOther(void** other) override; + + std::pair other_val_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_ diff --git a/oneflow/core/actor/reduce_scatter_compute_actor.cpp b/oneflow/core/actor/reduce_scatter_compute_actor.cpp new file mode 100644 index 0000000000..5b5bb5adc9 --- /dev/null +++ b/oneflow/core/actor/reduce_scatter_compute_actor.cpp @@ -0,0 +1,16 @@ +#include "oneflow/core/actor/reduce_scatter_compute_actor.h" + +namespace oneflow { + +void ReduceScatterCompActor::VirtualCompActorInit(const TaskProto& proto) { + InputWiseCompActor::Init(proto); +} + +void ReduceScatterCompActor::SetKernelCtxOther(void** other) { + other_val_ = EnableInplace(); + *other = static_cast(&other_val_); +} + +REGISTER_ACTOR(TaskType::kReduceScatter, ReduceScatterCompActor); + +} // namespace oneflow diff --git a/oneflow/core/actor/reduce_scatter_compute_actor.h b/oneflow/core/actor/reduce_scatter_compute_actor.h new file mode 100644 index 0000000000..ef6e7a05f1 --- /dev/null +++ b/oneflow/core/actor/reduce_scatter_compute_actor.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_ +#define ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_ + +#include "oneflow/core/actor/input_wise_compute_actor.h" + +namespace oneflow { + +class ReduceScatterCompActor final : public InputWiseCompActor { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceScatterCompActor); + ReduceScatterCompActor() = default; + ~ReduceScatterCompActor() = default; + + private: + void VirtualCompActorInit(const TaskProto& proto) override; + void SetKernelCtxOther(void** other) override; + + bool other_val_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_ diff --git a/oneflow/core/actor/reduce_split_compute_actor.cpp b/oneflow/core/actor/reduce_split_compute_actor.cpp new file mode 100644 index 0000000000..d5e58b32fa --- /dev/null +++ b/oneflow/core/actor/reduce_split_compute_actor.cpp @@ -0,0 +1,16 @@ +#include "oneflow/core/actor/reduce_split_compute_actor.h" + +namespace oneflow { + +void ReduceSplitCompActor::VirtualCompActorInit(const TaskProto& proto) { + InputWiseCompActor::Init(proto); +} + +void ReduceSplitCompActor::SetKernelCtxOther(void** other) { + other_val_ = EnableInplace(); + *other = static_cast(&other_val_); +} + +REGISTER_ACTOR(TaskType::kReduceSplit, ReduceSplitCompActor); + +} // namespace oneflow diff --git a/oneflow/core/actor/reduce_split_compute_actor.h b/oneflow/core/actor/reduce_split_compute_actor.h new file mode 100644 index 0000000000..7cbdc0c582 --- /dev/null +++ b/oneflow/core/actor/reduce_split_compute_actor.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_ +#define ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_ + +#include "oneflow/core/actor/input_wise_compute_actor.h" + +namespace oneflow { + +class ReduceSplitCompActor final : public InputWiseCompActor { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompActor); + ReduceSplitCompActor() = default; + ~ReduceSplitCompActor() = default; + + private: + void VirtualCompActorInit(const TaskProto& proto) override; + void SetKernelCtxOther(void** other) override; + + bool other_val_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_ diff --git a/oneflow/core/comm_network/comm_network.cpp b/oneflow/core/comm_network/comm_network.cpp index 791c3bc743..9ed0fa803b 100644 --- a/oneflow/core/comm_network/comm_network.cpp +++ b/oneflow/core/comm_network/comm_network.cpp @@ -69,34 +69,13 @@ void CommNet::AddWorkToStream(void* actor_read_id, const std::function& } CommNet::CommNet(const Plan& plan) { - HashMap rid2mid; - HashMap tid2mid; int64_t this_machine_id = Global::Get()->this_machine_id(); - - for (const TaskProto& task_proto : plan.task()) { - for (const auto& regst_desc_it : task_proto.produced_regst_desc()) { - rid2mid.emplace(regst_desc_it.second.regst_desc_id(), task_proto.machine_id()); - } - CHECK(tid2mid.emplace(task_proto.task_id(), task_proto.machine_id()).second); - } - for (const TaskProto& task_proto : plan.task()) { - if (task_proto.machine_id() != this_machine_id) { continue; } - for (const auto& regst_desc_set_it : task_proto.consumed_regst_desc_id()) { - for (int64_t regst_desc_id : regst_desc_set_it.second.regst_desc_id()) { - auto rid2mid_it = rid2mid.find(regst_desc_id); - CHECK(rid2mid_it != rid2mid.end()); - peer_machine_id_.insert(rid2mid_it->second); - } - } - for (const auto& regst_desc_it : task_proto.produced_regst_desc()) { - for (int64_t consumer_task_id : regst_desc_it.second.consumer_task_id()) { - auto tid2mid_it = tid2mid.find(consumer_task_id); - CHECK(tid2mid_it != tid2mid.end()); - peer_machine_id_.insert(tid2mid_it->second); - } - } - } - peer_machine_id_.erase(this_machine_id); + HashMap net_topo; + net_topo = PbMap2HashMap(plan.net_topo().peer_machine_ids()); + auto machine_ids_it = net_topo.find(this_machine_id); + CHECK(machine_ids_it != net_topo.end()); + std::vector peer_machine_ids = PbRf2StdVec(machine_ids_it->second.machine_id()); + peer_machine_id_.insert(peer_machine_ids.begin(), peer_machine_ids.end()); ready_cb_poller_ = std::thread([this]() { std::function cb; diff --git a/oneflow/core/graph/chain_logical_graph.cpp b/oneflow/core/graph/chain_logical_graph.cpp new file mode 100644 index 0000000000..16fcfeab68 --- /dev/null +++ b/oneflow/core/graph/chain_logical_graph.cpp @@ -0,0 +1,186 @@ +#include "oneflow/core/operator/fully_connected_op.h" +#include "oneflow/core/graph/chain_logical_graph.h" +#include "oneflow/core/graph/logical_graph.h" + +namespace oneflow { + +struct ChainLogicalGraph::Chain { + std::vector nodes; + HashSet ancestors; + HashSet ancestors_and_this; + HashSet descendants; + HashSet descendants_and_this; + bool is_mergeable; + + bool IsParallelDescEqual(const Chain& rhs) const { + CHECK_GT(nodes.size(), 0); + CHECK_GT(rhs.nodes.size(), 0); + return nodes.front()->parallel_desc()->Equal(rhs.nodes.front()->parallel_desc().get()); + } +}; + +ChainLogicalGraph::ChainLogicalGraph(const LogicalGraph& logical_graph) { + std::list chain_list; + HashMap::iterator> logical2chain_it; + HashMap logical2order_in_topo; + + InitChains(logical_graph, &chain_list, &logical2chain_it, &logical2order_in_topo); + MergeChains(&chain_list, &logical2chain_it); + SortNodesInChains(&chain_list, logical2order_in_topo); + BuildGraph(logical_graph, &chain_list); + ToDotWithAutoFilePath(); +} + +void ChainLogicalGraph::InitChains( + const LogicalGraph& logical_graph, std::list* chain_list, + HashMap::iterator>* logical2chain_it, + HashMap* logical2order_in_topo) { + logical_graph.ForEachNode([&](const LogicalNode* node) { + chain_list->emplace_back(); + logical2chain_it->insert({node, --chain_list->end()}); + Chain& chain = chain_list->back(); + chain.nodes = {node}; + chain.is_mergeable = IsLogicalNodeMergeable(node); + size_t order_in_topo = logical2order_in_topo->size(); + logical2order_in_topo->emplace(node, order_in_topo); + }); + + logical_graph.TopoForEachNode([&](const LogicalNode* node) { + auto cur_chain = logical2chain_it->at(node); + for (const LogicalEdge* edge : node->in_edges()) { + LogicalNode* pred_node = edge->src_node(); + auto pred_chain = logical2chain_it->at(pred_node); + cur_chain->ancestors.insert(pred_chain->ancestors.begin(), pred_chain->ancestors.end()); + cur_chain->ancestors.insert(pred_node); + } + cur_chain->ancestors_and_this.insert(cur_chain->ancestors.begin(), cur_chain->ancestors.end()); + cur_chain->ancestors_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end()); + }); + + logical_graph.ReverseTopoForEachNode([&](const LogicalNode* node) { + auto cur_chain = logical2chain_it->at(node); + for (const LogicalEdge* edge : node->out_edges()) { + LogicalNode* succ_node = edge->dst_node(); + auto succ_chain = logical2chain_it->at(succ_node); + cur_chain->descendants.insert(succ_chain->descendants.begin(), succ_chain->descendants.end()); + cur_chain->descendants.insert(succ_node); + } + cur_chain->descendants_and_this.insert(cur_chain->descendants.begin(), + cur_chain->descendants.end()); + cur_chain->descendants_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end()); + }); +} + +void ChainLogicalGraph::MergeChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it) { + while (chain_list->size() > 1 && TryMergeTwoChains(chain_list, logical2chain_it)) {}; +} + +bool ChainLogicalGraph::TryMergeTwoChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it) { + return TryMergeTwoParallelChains(chain_list, logical2chain_it) + || TryMergeTwoConnectedChains(chain_list, logical2chain_it); +} + +bool ChainLogicalGraph::TryMergeTwoParallelChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it) { + for (auto lhs = chain_list->begin(); lhs != chain_list->end(); ++lhs) { + if (!lhs->is_mergeable) { continue; } + for (auto rhs = lhs; rhs != chain_list->end(); ++rhs) { + if (lhs == rhs) { continue; } + if (!rhs->is_mergeable) { continue; } + if (!lhs->IsParallelDescEqual(*rhs)) { continue; } + if (lhs->ancestors != rhs->ancestors || lhs->descendants != rhs->descendants) { continue; } + for (const LogicalNode* node : rhs->nodes) { + lhs->nodes.push_back(node); + lhs->ancestors_and_this.insert(node); + lhs->descendants_and_this.insert(node); + logical2chain_it->at(node) = lhs; + } + chain_list->erase(rhs); + return true; + } + } + return false; +} + +bool ChainLogicalGraph::TryMergeTwoConnectedChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it) { + for (auto succ_chain_it = chain_list->begin(); succ_chain_it != chain_list->end(); + ++succ_chain_it) { + if (!succ_chain_it->is_mergeable) { continue; } + for (const LogicalNode* node_in_succ : succ_chain_it->nodes) { + for (const LogicalEdge* in_edge : node_in_succ->in_edges()) { + auto pred_chain_it = logical2chain_it->at(in_edge->src_node()); + if (pred_chain_it == succ_chain_it) { continue; } + if (!pred_chain_it->is_mergeable) { continue; } + if (!pred_chain_it->IsParallelDescEqual(*succ_chain_it)) { continue; } + if (pred_chain_it->ancestors_and_this != succ_chain_it->ancestors + || pred_chain_it->descendants != succ_chain_it->descendants_and_this) { + continue; + } + for (const LogicalNode* node : succ_chain_it->nodes) { + pred_chain_it->nodes.push_back(node); + pred_chain_it->ancestors_and_this.insert(node); + pred_chain_it->descendants.erase(node); + logical2chain_it->at(node) = pred_chain_it; + } + chain_list->erase(succ_chain_it); + return true; + } + } + } + return false; +} + +void ChainLogicalGraph::SortNodesInChains( + std::list* chain_list, + const HashMap& logical2order_in_topo) { + for (Chain& chain : *chain_list) { + std::sort(chain.nodes.begin(), chain.nodes.end(), + [&](const LogicalNode* a, const LogicalNode* b) { + return logical2order_in_topo.at(a) < logical2order_in_topo.at(b); + }); + } +} + +void ChainLogicalGraph::BuildGraph(const LogicalGraph& logical_graph, + std::list* chain_list) { + HashMap logical_node2chain_logical_node; + + for (const Chain& chain : *chain_list) { + ChainLogicalNode* chain_logical_node = NewNode(); + chain_logical_node->mut_logical_nodes() = chain.nodes; + for (const LogicalNode* node : chain.nodes) { + CHECK(logical_node2chain_logical_node.emplace(node, chain_logical_node).second); + } + } + + std::unordered_set> pred_succ_pairs; + logical_graph.ForEachEdge([&](const LogicalEdge* edge) { + pred_succ_pairs.emplace(logical_node2chain_logical_node.at(edge->src_node()), + logical_node2chain_logical_node.at(edge->dst_node())); + }); + + for (auto& pair : pred_succ_pairs) { + if (pair.first == pair.second) { continue; } + ChainLogicalEdge* edge = NewEdge(); + Connect(pair.first, edge, pair.second); + } +} + +bool ChainLogicalGraph::IsLogicalNodeMergeable(const LogicalNode* logical_node) const { + if (logical_node->parallel_desc()->policy() != kDataParallel) { return false; } + if (!dynamic_cast(logical_node)) { return false; } + for (const std::shared_ptr& op : logical_node->op_vec()) { + if (dynamic_cast(op.get())) { return false; } + if (op->IsRecurrentOp()) { return false; } + } + return true; +} + +} // namespace oneflow diff --git a/oneflow/core/graph/chain_logical_graph.h b/oneflow/core/graph/chain_logical_graph.h new file mode 100644 index 0000000000..bb95bcc433 --- /dev/null +++ b/oneflow/core/graph/chain_logical_graph.h @@ -0,0 +1,60 @@ +#ifndef ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ +#define ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ + +#include "oneflow/core/graph/graph.h" +#include "oneflow/core/graph/logical_graph.h" + +namespace oneflow { + +class ChainLogicalEdge; + +class ChainLogicalNode final : public Node { + public: + OF_DISALLOW_COPY_AND_MOVE(ChainLogicalNode); + ChainLogicalNode() = default; + ~ChainLogicalNode() override = default; + + const std::vector& logical_nodes() const { return logical_nodes_; } + std::vector& mut_logical_nodes() { return logical_nodes_; } + + private: + std::vector logical_nodes_; +}; + +class ChainLogicalEdge final : public Edge { + public: + OF_DISALLOW_COPY_AND_MOVE(ChainLogicalEdge); + ChainLogicalEdge() = default; + ~ChainLogicalEdge() override = default; +}; + +class ChainLogicalGraph final : public Graph { + public: + OF_DISALLOW_COPY_AND_MOVE(ChainLogicalGraph); + explicit ChainLogicalGraph(const LogicalGraph& logical_graph); + ~ChainLogicalGraph() override = default; + + private: + struct Chain; + void InitChains(const LogicalGraph& logical_graph, std::list* chain_list, + HashMap::iterator>* logical2chain_it, + HashMap* logical2order_in_topo); + void MergeChains(std::list* chain_list, + HashMap::iterator>* logical2chain_it); + bool TryMergeTwoChains(std::list* chain_list, + HashMap::iterator>* logical2chain_it); + bool TryMergeTwoParallelChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it); + bool TryMergeTwoConnectedChains( + std::list* chain_list, + HashMap::iterator>* logical2chain_it); + void SortNodesInChains(std::list* chain_list, + const HashMap& logical2order_in_topo); + void BuildGraph(const LogicalGraph& logical_graph, std::list* chain_list); + bool IsLogicalNodeMergeable(const LogicalNode* logical_node) const; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index 4ea28f3214..f295dc1b99 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -148,10 +148,12 @@ template template void Graph::ToDotWithStream(StreamT& out_stream) { out_stream << "digraph {\n"; - this->ForEachNode([&](NodeType* node) { out_stream << "\"" << node->VisualStr() << "\"\n"; }); + this->ForEachNode([&](NodeType* node) { + out_stream << "\"" << node->node_id_str() << "\" [label=\"" << node->VisualStr() << "\"]\n"; + }); this->ForEachEdge([&](const EdgeType* edge) { - out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> " - << "\"" << edge->dst_node()->VisualStr() << "\"" + out_stream << "\"" << edge->src_node()->node_id_str() << "\" -> " + << "\"" << edge->dst_node()->node_id_str() << "\"" << "[label=\"" << edge->VisualStr() << "\"];\n"; }); out_stream << "}\n"; diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index 9492e34d2f..b4404c066f 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -2,11 +2,14 @@ #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/graph/chain_logical_graph.h" +#include "oneflow/core/common/balanced_splitter.h" namespace oneflow { LogicalGraph::LogicalGraph(bool is_train) { BuildFwStruct(); + if (is_train) { GroupNodesForReduceStruct(); } SetMainModelParallel(); if (is_train) { BuildBwStruct(); } MergeEdge(); @@ -20,6 +23,30 @@ LogicalGraph::LogicalGraph(bool is_train) { ToDotWithAutoFilePath(); } +void LogicalGraph::GroupNodesForReduceStruct() { + ChainLogicalGraph chain_logical_graph(*this); + std::vector> fw_node_groups; + chain_logical_graph.ForEachNode( + [&](ChainLogicalNode* node) { fw_node_groups.emplace_back(node->logical_nodes()); }); + for (auto& fw_node_group : fw_node_groups) { + if (fw_node_group.size() < Global::Get()->reduce_group_size()) { + fw_node_groups_.emplace_back(std::move(fw_node_group)); + } else { + int64_t fw_node_group_size = fw_node_group.size(); + int64_t seg_num = fw_node_group_size / Global::Get()->reduce_group_size() + 1; + BalancedSplitter bs(fw_node_group_size, seg_num); + FOR_RANGE(int64_t, idx, 0, seg_num) { + std::vector sub_fw_node_group; + Range range = bs.At(idx); + FOR_RANGE(int64_t, nid, range.begin(), range.end()) { + sub_fw_node_group.emplace_back(fw_node_group[nid]); + } + fw_node_groups_.emplace_back(std::move(sub_fw_node_group)); + } + } + } +} + template void LogicalGraph::ForEachLogicalNode(std::function func) { std::vector valid_nodes; @@ -352,6 +379,7 @@ void LogicalGraph::BuildAccuracyPrintStruct() { void LogicalGraph::BuildModelStruct(bool is_train) { HashMap first_shared2mdupdt; + HashMap fw_node2reduce_ctx; ForEachLogicalNode([&](ForwardLogicalNode* fw_logical) { if (Global::Get()->enable_write_snapshot() && fw_logical->HasOpWithForwardModelBlob()) { @@ -397,17 +425,69 @@ void LogicalGraph::BuildModelStruct(bool is_train) { } if (md_diff_acc_logical->parallel_desc()->parallel_num() > 1 && md_diff_acc_logical->parallel_desc()->policy() == kDataParallel) { - BuildReduceStruct(md_diff_acc_logical, md_updt_logical); + ReduceCtx reduce_ctx; + reduce_ctx.fw_logicals.emplace_back(fw_logical); + reduce_ctx.bw_logicals.emplace_back(bw_logical); + reduce_ctx.md_diff_acc_logicals.emplace_back(md_diff_acc_logical); + reduce_ctx.md_updt_logicals.emplace_back(md_updt_logical); + CHECK(fw_node2reduce_ctx.emplace(fw_logical, reduce_ctx).second); } else { Connect(md_diff_acc_logical, NewEdge(), md_updt_logical); } } } }); + for (auto& fw_node_group : fw_node_groups_) { + ReduceCtx group_reduce_ctx; + for (auto& fw_node : fw_node_group) { + auto reduce_ctx_it = fw_node2reduce_ctx.find(fw_node); + if (reduce_ctx_it != fw_node2reduce_ctx.end()) { + auto& reduce_ctx = reduce_ctx_it->second; + group_reduce_ctx.fw_logicals.emplace_back(reduce_ctx.fw_logicals.at(0)); + group_reduce_ctx.bw_logicals.emplace_back(reduce_ctx.bw_logicals.at(0)); + group_reduce_ctx.md_diff_acc_logicals.emplace_back(reduce_ctx.md_diff_acc_logicals.at(0)); + group_reduce_ctx.md_updt_logicals.emplace_back(reduce_ctx.md_updt_logicals.at(0)); + } + } + BuildReduceStruct(group_reduce_ctx); + } SetupNormalMdUpdtOp(); } -void LogicalGraph::BuildReduceStruct(LogicalNode* src, LogicalNode* dst) { +void LogicalGraph::BuildReduceStruct(const ReduceCtx& reduce_ctx) { + if (reduce_ctx.fw_logicals.size() > 1) { + std::shared_ptr src_pd = reduce_ctx.fw_logicals[0]->parallel_desc(); + + OperatorConf reduce_concat_op_conf; + reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId()); + reduce_concat_op_conf.set_device_type(src_pd->device_type()); + reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size()); + LogicalNode* reduce_concat_node = NewNode(); + reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)}; + reduce_concat_node->mut_parallel_desc() = src_pd; + + OperatorConf reduce_split_op_conf; + reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId()); + reduce_split_op_conf.set_device_type(src_pd->device_type()); + reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size()); + LogicalNode* reduce_split_node = NewNode(); + reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)}; + reduce_split_node->mut_parallel_desc() = src_pd; + + for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) { + Connect(md_diff_acc_node, NewEdge(), reduce_concat_node); + } + for (auto& md_updt_node : reduce_ctx.md_updt_logicals) { + Connect(reduce_split_node, NewEdge(), md_updt_node); + } + AddReduceScatterAddGatherNodes(reduce_concat_node, reduce_split_node); + } else if (reduce_ctx.fw_logicals.size() == 1) { + AddReduceScatterAddGatherNodes(reduce_ctx.md_diff_acc_logicals.at(0), + reduce_ctx.md_updt_logicals.at(0)); + } +} + +void LogicalGraph::AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode* dst) { std::shared_ptr src_pd = src->parallel_desc(); std::shared_ptr dst_pd = dst->parallel_desc(); CHECK_EQ(src_pd->parallel_num(), dst_pd->parallel_num()); diff --git a/oneflow/core/graph/logical_graph.h b/oneflow/core/graph/logical_graph.h index a032c3132a..be32122fc6 100644 --- a/oneflow/core/graph/logical_graph.h +++ b/oneflow/core/graph/logical_graph.h @@ -27,8 +27,15 @@ class LogicalGraph final : public Graph { LogicalBlobId lbi; std::vector edges; }; + struct ReduceCtx { + std::vector fw_logicals; + std::vector bw_logicals; + std::vector md_diff_acc_logicals; + std::vector md_updt_logicals; + }; template void ForEachLogicalNode(std::function Handler); + void GroupNodesForReduceStruct(); void BuildFwStruct(); void NaiveBuildFwStruct(HashMap>* op_name2nodes); @@ -46,7 +53,8 @@ class LogicalGraph final : public Graph { void BuildLossPrintStruct(); void BuildAccuracyPrintStruct(); void BuildModelStruct(bool is_train); - void BuildReduceStruct(LogicalNode* src, LogicalNode* dst); + void AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode* dst); + void BuildReduceStruct(const ReduceCtx& reduce_ctx); void SetupNormalMdUpdtOp(); MdSaveLogicalNode* BuildMdSaveStruct(const ForwardLogicalNode* fw_logical, LogicalNode* need_save_logical); @@ -58,6 +66,7 @@ class LogicalGraph final : public Graph { int64_t total_mbn_num_; + std::vector> fw_node_groups_; HashMap edge2ibn_; HashMap edge2obn_; }; diff --git a/oneflow/core/graph/logical_node.cpp b/oneflow/core/graph/logical_node.cpp index 5aca756e17..9152d93c20 100644 --- a/oneflow/core/graph/logical_node.cpp +++ b/oneflow/core/graph/logical_node.cpp @@ -14,6 +14,8 @@ #include "oneflow/core/graph/reduce_local_add_compute_task_node.h" #include "oneflow/core/graph/reduce_global_add_compute_task_node.h" #include "oneflow/core/graph/reduce_gather_compute_task_node.h" +#include "oneflow/core/graph/reduce_concat_compute_task_node.h" +#include "oneflow/core/graph/reduce_split_compute_task_node.h" #include "oneflow/core/graph/accuracy_compute_task_node.h" #include "oneflow/core/graph/accuracy_accumulate_compute_task_node.h" #include "oneflow/core/graph/accuracy_print_compute_task_node.h" @@ -335,6 +337,18 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc" REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward" "NormalMdUpdt", BldSubTskGphToNormalMdUpdt); +REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward" + "ReduceConcat", + &TaskGraph::BldSubTskGphByOneToOne); +REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc" + "ReduceConcat", + &TaskGraph::BldSubTskGphByOneToOne); +REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceConcat" + "ReduceScatter", + &TaskGraph::BldSubTskGphByOneToOne); +REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward" + "ReduceScatter", + &TaskGraph::BldSubTskGphByOneToOne); REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc" "ReduceScatter", &TaskGraph::BldSubTskGphByOneToOne); @@ -353,6 +367,12 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGlobalAdd" REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGather" "NormalMdUpdt", &TaskGraph::BldSubTskGphByOneToOne); +REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceSplit" + "NormalMdUpdt", + &TaskGraph::BldSubTskGphByOneToOne); +REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGather" + "ReduceSplit", + &TaskGraph::BldSubTskGphByOneToOne); BldBoxingOpConfMthd GetMthdForBldBoxingOpConf(const LogicalNode* src, const LogicalNode* dst) { std::string k = ConcatTypeName(src, dst); @@ -395,10 +415,12 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("NormalBackward" OF_PP_MAKE_TUPLE_SEQ(MdSave, kMdSaveArea) \ OF_PP_MAKE_TUPLE_SEQ(MdDiffAcc, kDataBackwardArea) \ OF_PP_MAKE_TUPLE_SEQ(Print, kPrintArea) \ + OF_PP_MAKE_TUPLE_SEQ(ReduceConcat, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceScatter, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceLocalAdd, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceGlobalAdd, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea) \ + OF_PP_MAKE_TUPLE_SEQ(ReduceSplit, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(Accuracy, kDataForwardArea) \ OF_PP_MAKE_TUPLE_SEQ(AccuracyAcc, kDataForwardArea) \ OF_PP_MAKE_TUPLE_SEQ(AccuracyPrint, kPrintArea) diff --git a/oneflow/core/graph/logical_node.h b/oneflow/core/graph/logical_node.h index a804be9338..0563847123 100644 --- a/oneflow/core/graph/logical_node.h +++ b/oneflow/core/graph/logical_node.h @@ -203,6 +203,8 @@ DECLARE_NAIVE_LOGICAL_NODE(ReduceScatterLogicalNode); DECLARE_NAIVE_LOGICAL_NODE(ReduceLocalAddLogicalNode); DECLARE_NAIVE_LOGICAL_NODE(ReduceGlobalAddLogicalNode); DECLARE_NAIVE_LOGICAL_NODE(ReduceGatherLogicalNode); +DECLARE_NAIVE_LOGICAL_NODE(ReduceConcatLogicalNode); +DECLARE_NAIVE_LOGICAL_NODE(ReduceSplitLogicalNode); } // namespace oneflow diff --git a/oneflow/core/graph/normal_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp index dff319e627..fba95c942a 100644 --- a/oneflow/core/graph/normal_backward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp @@ -11,7 +11,8 @@ void NormalBackwardCompTaskNode::ProduceAllRegstsAndBindEdges() { for (TaskEdge* edge : out_edges()) { const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge); if (succ_logical->TypeName() == "MdDiffAcc" || succ_logical->TypeName() == "NormalMdUpdt" - || succ_logical->TypeName() == "ReduceScatter") { + || succ_logical->TypeName() == "ReduceScatter" + || succ_logical->TypeName() == "ReduceConcat") { edge->AddRegst("model_diff", ProduceRegst("model_diff", true)); } else { BindEdgeWithProducedRegst(edge, "in_diff"); diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.cpp b/oneflow/core/graph/normal_model_update_compute_task_node.cpp index caf9c49e49..02ec24fa9d 100644 --- a/oneflow/core/graph/normal_model_update_compute_task_node.cpp +++ b/oneflow/core/graph/normal_model_update_compute_task_node.cpp @@ -43,7 +43,10 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { void NormalMdUpdtCompTaskNode::ConsumeAllRegsts() { if (!IsTrainable()) { return; } for (TaskEdge* edge : in_edges()) { - ConsumeRegst("model_diff_acc_" + NewUniqueId(), edge->GetSoleRegst()); + auto regst_descs = edge->GetRegsts(); + for (auto& regst_desc : regst_descs) { + ConsumeRegst("model_diff_acc_" + NewUniqueId(), regst_desc); + } } } diff --git a/oneflow/core/graph/reduce_concat_compute_task_node.cpp b/oneflow/core/graph/reduce_concat_compute_task_node.cpp new file mode 100644 index 0000000000..41fdd291fa --- /dev/null +++ b/oneflow/core/graph/reduce_concat_compute_task_node.cpp @@ -0,0 +1,47 @@ +#include "oneflow/core/graph/reduce_concat_compute_task_node.h" +#include "oneflow/core/graph/logical_node.h" + +namespace oneflow { + +void ReduceConcatCompTaskNode::ProduceAllRegstsAndBindEdges() { + this->SoleOutEdge()->AddRegst("out", ProduceRegst("out", false, 1, 1)); +} + +void ReduceConcatCompTaskNode::ConsumeAllRegsts() { + struct EdgeInfo { + int64_t bw_node_order; + TaskEdge* edge; + }; + std::vector edge_infos; + for (TaskEdge* edge : in_edges()) { + TaskNode* src_node = edge->src_node(); + while (src_node->GetTaskType() != TaskType::kNormalBackward) { + src_node = src_node->SoleInEdge()->src_node(); + } + CompTaskNode* bw_node = dynamic_cast(src_node); + EdgeInfo edge_info{bw_node->order_in_graph(), edge}; + edge_infos.emplace_back(edge_info); + } + std::sort(edge_infos.begin(), edge_infos.end(), [](const EdgeInfo& lhs, const EdgeInfo& rhs) { + return lhs.bw_node_order < rhs.bw_node_order; + }); + FOR_RANGE(size_t, idx, 0, edge_infos.size()) { + ConsumeRegst("in_" + std::to_string(idx), edge_infos[idx].edge->GetSoleRegst()); + } +} + +void ReduceConcatCompTaskNode::BuildExecGphAndRegst() { + ExecNode* node = mut_exec_gph().NewNode(); + std::shared_ptr reduce_concat_op = this->logical_node()->SoleOp(); + node->mut_op() = reduce_concat_op; + FOR_RANGE(size_t, i, 0, reduce_concat_op->input_bns().size()) { + node->BindBnWithRegst(reduce_concat_op->input_bns().Get(i), + GetSoleConsumedRegst("in_" + std::to_string(i))); + } + std::shared_ptr out_regst = GetProducedRegst("out"); + out_regst->AddLbi(reduce_concat_op->BnInOp2Lbi(reduce_concat_op->SoleObn())); + node->BindBnWithRegst(reduce_concat_op->SoleObn(), out_regst); + node->InferBlobDescs(parallel_ctx()); +} + +} // namespace oneflow diff --git a/oneflow/core/graph/reduce_concat_compute_task_node.h b/oneflow/core/graph/reduce_concat_compute_task_node.h new file mode 100644 index 0000000000..a074d6debe --- /dev/null +++ b/oneflow/core/graph/reduce_concat_compute_task_node.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_ + +#include "oneflow/core/graph/compute_task_node.h" + +namespace oneflow { + +class ReduceConcatCompTaskNode final : public CompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompTaskNode); + ReduceConcatCompTaskNode() = default; + ~ReduceConcatCompTaskNode() = default; + + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() override; + + TaskType GetTaskType() const override { return TaskType::kReduceConcat; } + CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; } + + private: + void BuildExecGphAndRegst() override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/reduce_split_compute_task_node.cpp b/oneflow/core/graph/reduce_split_compute_task_node.cpp new file mode 100644 index 0000000000..18ea28b199 --- /dev/null +++ b/oneflow/core/graph/reduce_split_compute_task_node.cpp @@ -0,0 +1,89 @@ +#include "oneflow/core/graph/reduce_split_compute_task_node.h" +#include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/register/register_desc.h" + +namespace oneflow { + +void ReduceSplitCompTaskNode::ProduceAllRegstsAndBindEdges() { + struct EdgeInfo { + int64_t bw_node_order; + TaskEdge* edge; + }; + std::vector edge_infos; + for (TaskEdge* edge : out_edges()) { + TaskNode* dst_node = edge->dst_node(); + CHECK(dst_node->GetTaskType() == TaskType::kNormalMdUpdt); + CompTaskNode* mdupdt_node = dynamic_cast(dst_node); + for (TaskEdge* mdupdt_edge : mdupdt_node->out_edges()) { + if (IsBackwardTaskType(mdupdt_edge->dst_node()->GetTaskType())) { + CompTaskNode* bw_node = dynamic_cast(mdupdt_edge->dst_node()); + // There may be multiple out_regsts on the same edge for shared_model app + EdgeInfo edge_info{bw_node->order_in_graph(), edge}; + edge_infos.emplace_back(edge_info); + } + } + } + std::sort(edge_infos.begin(), edge_infos.end(), [](const EdgeInfo& lhs, const EdgeInfo& rhs) { + return lhs.bw_node_order < rhs.bw_node_order; + }); + FOR_RANGE(size_t, idx, 0, edge_infos.size()) { + std::string out_regst_name = "out_" + std::to_string(idx); + std::shared_ptr out_regst = ProduceRegst(out_regst_name, false, 1, 1); + edge_infos[idx].edge->AddRegst(out_regst_name, out_regst); + } +} + +void ReduceSplitCompTaskNode::ConsumeAllRegsts() { + ConsumeRegst("in", this->SoleInEdge()->GetSoleRegst()); +} + +void ReduceSplitCompTaskNode::BuildExecGphAndRegst() { + ExecNode* node = mut_exec_gph().NewNode(); + std::shared_ptr reduce_split_op = this->logical_node()->SoleOp(); + node->mut_op() = reduce_split_op; + node->BindBnWithRegst(reduce_split_op->SoleIbn(), GetSoleConsumedRegst("in")); + + CompTaskNode* reduce_concat_node = FindPeerReduceConcatTaskNode(); + CHECK_EQ(reduce_concat_node->consumed_regsts().size(), produced_regsts().size()); + + FOR_RANGE(size_t, i, 0, reduce_split_op->output_bns().size()) { + std::shared_ptr out_regst = GetProducedRegst("out_" + std::to_string(i)); + CHECK(out_regst.get() != nullptr); + out_regst->CopyBlobDescFrom( + reduce_concat_node->GetSoleConsumedRegst("in_" + std::to_string(i)).get()); + node->BindBnWithRegst(reduce_split_op->output_bns().Get(i), out_regst); + } +} + +CompTaskNode* ReduceSplitCompTaskNode::FindPeerReduceConcatTaskNode() { + CompTaskNode* src_node = this; + bool found_direct_node = true; + + while (src_node->GetTaskType() != TaskType::kReduceConcat && found_direct_node) { + found_direct_node = false; + for (TaskEdge* edge : src_node->in_edges()) { + CompTaskNode* comp_task_node = dynamic_cast(edge->src_node()); + if (comp_task_node != nullptr) { + src_node = comp_task_node; + CHECK(src_node->GetTaskType() != TaskType::kNormalBackward); + found_direct_node = true; + break; + } + } + if (found_direct_node == false) { break; } + } + return src_node; +} + +void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() { + int64_t out_regst_num = produced_regsts().size(); + FOR_RANGE(int64_t, idx, 0, out_regst_num) { + std::shared_ptr out_regst = GetProducedRegst("out_" + std::to_string(idx)); + CHECK(out_regst->IsLocked()); + Shape& shape = out_regst->MutBlobDesc(GenPackedLbi())->mut_shape(); + shape = + Shape({static_cast(RoundUp(shape.elem_cnt(), parallel_ctx()->parallel_num()))}); + } +} + +} // namespace oneflow diff --git a/oneflow/core/graph/reduce_split_compute_task_node.h b/oneflow/core/graph/reduce_split_compute_task_node.h new file mode 100644 index 0000000000..dd9be5603a --- /dev/null +++ b/oneflow/core/graph/reduce_split_compute_task_node.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_ + +#include "oneflow/core/graph/compute_task_node.h" + +namespace oneflow { + +class ReduceSplitCompTaskNode final : public CompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompTaskNode); + ReduceSplitCompTaskNode() = default; + ~ReduceSplitCompTaskNode() = default; + + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() override; + + TaskType GetTaskType() const override { return TaskType::kReduceSplit; } + CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; } + + private: + void BuildExecGphAndRegst() override; + void FixPackedBlobDescOfProducedRegst() override; + CompTaskNode* FindPeerReduceConcatTaskNode(); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 833e199f17..a8f5db123f 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -55,6 +55,7 @@ TaskGraph::TaskGraph(std::unique_ptr&& logical_gph) { &logical2sorted_out_box, MutBufTask, AllocateCpuThrdIdEvenly); SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node()); }); + MergeChainAndSetOrderInGraphForEachNode(); ToDotWithAutoFilePath(); } @@ -100,10 +101,7 @@ void TaskGraph::RemoveEmptyRegsts() { ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); }); } -void TaskGraph::AddOrderingCtrlEdgeInSameChain() { - MergeChainAndSetOrderInGraphForEachNode(); - BuildCtrlRegstDescInSameChain(); -} +void TaskGraph::AddOrderingCtrlEdgeInSameChain() { BuildCtrlRegstDescInSameChain(); } void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() { ChainGraph chain_graph(*this); @@ -136,12 +134,12 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() { } struct ReduceTaskNodes { - CompTaskNode* scatter; - CompTaskNode* local_add; - CompTaskNode* global_add; - CompTaskNode* gather; - - ReduceTaskNodes() : scatter(nullptr), local_add(nullptr), global_add(nullptr), gather(nullptr) {} + CompTaskNode* concat = nullptr; + CompTaskNode* scatter = nullptr; + CompTaskNode* local_add = nullptr; + CompTaskNode* global_add = nullptr; + CompTaskNode* gather = nullptr; + CompTaskNode* split = nullptr; }; void TaskGraph::EnableMemSharingInReduceStruct() { @@ -163,6 +161,20 @@ void TaskGraph::CollectReduceTaskNodes( return nullptr; }; + auto FindConcatAndScatter = [&](CompTaskNode* bw_or_md_diff_acc, + ReduceTaskNodes* reduce_task_nodes) { + CompTaskNode* concat_task_node = + FindSuccReduceTaskNode(bw_or_md_diff_acc, TaskType::kReduceConcat); + if (concat_task_node != nullptr) { + reduce_task_nodes->concat = concat_task_node; + reduce_task_nodes->scatter = + FindSuccReduceTaskNode(reduce_task_nodes->concat, TaskType::kReduceScatter); + } else { + reduce_task_nodes->scatter = + FindSuccReduceTaskNode(bw_or_md_diff_acc, TaskType::kReduceScatter); + } + }; + ForEachNode([&](TaskNode* task_node) { if (IsBackwardTaskType(task_node->GetTaskType()) == false) { return; } if (task_node->device_type() != DeviceType::kGPU) { return; } @@ -175,15 +187,16 @@ void TaskGraph::CollectReduceTaskNodes( } ReduceTaskNodes& reduce_task_nodes = (*bw2reduce_tasks)[bw_task_node]; - CompTaskNode* tmp_task_node = FindSuccReduceTaskNode(bw_task_node, TaskType::kMdDiffAcc); - if (tmp_task_node != nullptr) { - reduce_task_nodes.scatter = FindSuccReduceTaskNode(tmp_task_node, TaskType::kReduceScatter); + CompTaskNode* diff_acc_task_node = FindSuccReduceTaskNode(bw_task_node, TaskType::kMdDiffAcc); + if (diff_acc_task_node != nullptr) { + FindConcatAndScatter(diff_acc_task_node, &reduce_task_nodes); } else { - reduce_task_nodes.scatter = FindSuccReduceTaskNode(bw_task_node, TaskType::kReduceScatter); + FindConcatAndScatter(bw_task_node, &reduce_task_nodes); } - tmp_task_node = FindSuccReduceTaskNode(reduce_task_nodes.scatter, TaskType::kReduceLocalAdd); - if (tmp_task_node != nullptr) { - reduce_task_nodes.local_add = tmp_task_node; + CompTaskNode* local_add_task_node = + FindSuccReduceTaskNode(reduce_task_nodes.scatter, TaskType::kReduceLocalAdd); + if (local_add_task_node != nullptr) { + reduce_task_nodes.local_add = local_add_task_node; reduce_task_nodes.global_add = FindSuccReduceTaskNode(reduce_task_nodes.local_add, TaskType::kReduceGlobalAdd); } else { @@ -192,6 +205,9 @@ void TaskGraph::CollectReduceTaskNodes( } reduce_task_nodes.gather = FindSuccReduceTaskNode(reduce_task_nodes.global_add, TaskType::kReduceGather); + reduce_task_nodes.split = + FindSuccReduceTaskNode(reduce_task_nodes.gather, TaskType::kReduceSplit); + if (reduce_task_nodes.split == nullptr) { CHECK(reduce_task_nodes.concat == nullptr); } CHECK(reduce_task_nodes.scatter != nullptr); CHECK(reduce_task_nodes.global_add != nullptr); @@ -199,6 +215,41 @@ void TaskGraph::CollectReduceTaskNodes( }); } +void TaskGraph::EnableMemSharingInReduceConcatSplitIfNeed( + const ReduceTaskNodes& reduce_task_nodes, + std::function SetMemSharedField4Regst) { + if (reduce_task_nodes.concat == nullptr) { return; } + int32_t reduce_num = reduce_task_nodes.split->produced_regsts().size(); + + std::shared_ptr concat_out_regst = reduce_task_nodes.concat->GetProducedRegst("out"); + std::shared_ptr split_in_regst = reduce_task_nodes.split->GetSoleConsumedRegst("in"); + const BlobDesc* concat_out_packed = concat_out_regst->GetBlobDesc(GenPackedLbi()); + const BlobDesc* split_in_packed = split_in_regst->GetBlobDesc(GenPackedLbi()); + size_t concat_out_byte_size = RtBlobDesc(*concat_out_packed).ByteSizeOfBlobBody(); + size_t split_in_byte_size = RtBlobDesc(*split_in_packed).ByteSizeOfBlobBody(); + CHECK_EQ(concat_out_byte_size, split_in_byte_size); + SetMemSharedField4Regst(concat_out_regst.get(), 0); + SetMemSharedField4Regst(split_in_regst.get(), 0); + + int64_t offset = 0; + FOR_RANGE(int32_t, idx, 0, reduce_num) { + auto concat_in_regst = + reduce_task_nodes.concat->GetSoleConsumedRegst("in_" + std::to_string(idx)); + auto split_out_regst = reduce_task_nodes.split->GetProducedRegst("out_" + std::to_string(idx)); + SetMemSharedField4Regst(concat_in_regst.get(), offset); + SetMemSharedField4Regst(split_out_regst.get(), offset); + + // Check shape invariant + const BlobDesc* concat_in_packed = concat_in_regst->GetBlobDesc(GenPackedLbi()); + const BlobDesc* split_out_packed = split_out_regst->GetBlobDesc(GenPackedLbi()); + size_t concat_in_byte_size = RtBlobDesc(*concat_in_packed).ByteSizeOfBlobBody(); + size_t split_out_byte_size = RtBlobDesc(*split_out_packed).ByteSizeOfBlobBody(); + CHECK_EQ(concat_in_byte_size, split_out_byte_size); + + offset += concat_in_byte_size; + } +} + void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_nodes) { std::shared_ptr parallel_desc = reduce_task_nodes.scatter->logical_node()->parallel_desc(); @@ -211,12 +262,14 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n int64_t mem_shared_id = Global::Get()->NewMemSharedId(); std::vector blob_index2offset(parallel_num, 0); - auto SetMemSharedField4Regst = [&](RegstDesc* regst, int64_t blob_id) { + auto SetMemSharedField4Regst = [&](RegstDesc* regst, int64_t offset) { regst->set_enable_mem_sharing(true); regst->set_mem_shared_id(mem_shared_id); - regst->set_mem_shared_offset(blob_index2offset.at(blob_id)); + regst->set_mem_shared_offset(offset); }; + EnableMemSharingInReduceConcatSplitIfNeed(reduce_task_nodes, SetMemSharedField4Regst); + // scatter { std::shared_ptr consumed_regst = @@ -233,7 +286,8 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n for (int64_t i = 0; i < parallel_num; ++i) { SetMemSharedField4Regst( - reduce_task_nodes.scatter->GetProducedRegst("out_" + std::to_string(i)).get(), i); + reduce_task_nodes.scatter->GetProducedRegst("out_" + std::to_string(i)).get(), + blob_index2offset.at(i)); } } @@ -243,7 +297,7 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n CHECK_EQ(mem_shared_id, consumed_regst->mem_shared_id()); CHECK_EQ(blob_index2offset.at(blob_id), consumed_regst->mem_shared_offset()); } else { - SetMemSharedField4Regst(consumed_regst, blob_id); + SetMemSharedField4Regst(consumed_regst, blob_index2offset.at(blob_id)); } }; @@ -267,7 +321,7 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n for (int64_t i = 0; i < machine_num; ++i) { SetMemSharedField4Regst( reduce_task_nodes.local_add->GetProducedRegst("out_" + std::to_string(i)).get(), - i * dev_num_of_each_machine + dev_index_of_this_machine); + blob_index2offset.at(i * dev_num_of_each_machine + dev_index_of_this_machine)); } } @@ -278,18 +332,20 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n for (const auto& kv : consumed_regsts) { int64_t in_parallel_id = oneflow_cast(kv.first.substr(3)); CHECK_EQ(1, kv.second.size()); - RegstDesc* consumed_regst = kv.second.front().get(); - SetOrCheck4ConsumedRegst(consumed_regst, in_parallel_id == parallel_id, in_parallel_id); + SetOrCheck4ConsumedRegst(kv.second.front().get(), in_parallel_id == parallel_id, + in_parallel_id); } }; // global add int consumed_regst_num = reduce_task_nodes.local_add ? machine_num : parallel_num; HandleMemSharedFieldOfConsumedRegsts(reduce_task_nodes.global_add, consumed_regst_num); - SetMemSharedField4Regst(reduce_task_nodes.global_add->GetProducedRegst("out").get(), parallel_id); + SetMemSharedField4Regst(reduce_task_nodes.global_add->GetProducedRegst("out").get(), + blob_index2offset.at(parallel_id)); // gather HandleMemSharedFieldOfConsumedRegsts(reduce_task_nodes.gather, parallel_num); - SetMemSharedField4Regst(reduce_task_nodes.gather->GetProducedRegst("out").get(), 0); + SetMemSharedField4Regst(reduce_task_nodes.gather->GetProducedRegst("out").get(), + blob_index2offset.at(0)); } void TaskGraph::AddCtrlEdge4MemSharingInOneReduce(const ReduceTaskNodes& reduce_task_nodes) { diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index 508ebe563d..e133b5acf8 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -87,6 +87,9 @@ class TaskGraph final : public Graph { template bool IsEndingTaskType(TaskType type); + void EnableMemSharingInReduceConcatSplitIfNeed( + const ReduceTaskNodes&, std::function SetMemSharedField4Regst); + void GeneratePersistenceThrdId( const std::vector>& persistence_nodes); diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index c4817c53e9..82498ec807 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -331,6 +331,12 @@ std::shared_ptr TaskEdge::GetSoleRegst() const { return name_in_producer2regst_.begin()->second; } +std::vector> TaskEdge::GetRegsts() const { + std::vector> regst_descs; + for (auto& pair : name_in_producer2regst_) { regst_descs.emplace_back(pair.second); } + return regst_descs; +} + void TaskEdge::AddRegst(const std::string& name_in_producer, std::shared_ptr regst) { CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); } @@ -356,12 +362,10 @@ RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto, } std::map task_type2color = { - {kInvalid, "0"}, {kNormalForward, "2"}, {kNormalBackward, "3"}, - {kRecordLoad, "1"}, {kDecode, "1"}, {kLoss, "4"}, - {kLossAcc, "5"}, {kLossPrint, "1"}, {kNormalMdUpdt, "6"}, - {kMdSave, "1"}, {kMdDiffAcc, "7"}, {kCopyHd, "8"}, - {kCopyCommNet, "9"}, {kBoxing, "10"}, {kPrint, "1"}, - {kReduceScatter, "2"}, {kReduceLocalAdd, "2"}, {kReduceGlobalAdd, "2"}, - {kReduceGather, "2"}, {kAccuracy, "4"}, {kAccuracyPrint, "1"}, - {kAccuracyAcc, "5"}}; + {kInvalid, "0"}, {kNormalForward, "2"}, {kNormalBackward, "3"}, {kRecordLoad, "1"}, + {kDecode, "1"}, {kLoss, "4"}, {kLossAcc, "5"}, {kLossPrint, "1"}, + {kNormalMdUpdt, "6"}, {kMdSave, "1"}, {kMdDiffAcc, "7"}, {kCopyHd, "8"}, + {kCopyCommNet, "9"}, {kBoxing, "10"}, {kPrint, "1"}, {kReduceConcat, "2"}, + {kReduceScatter, "2"}, {kReduceLocalAdd, "2"}, {kReduceGlobalAdd, "2"}, {kReduceGather, "2"}, + {kReduceSplit, "2"}, {kAccuracy, "4"}, {kAccuracyPrint, "1"}, {kAccuracyAcc, "5"}}; } // namespace oneflow diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 06b397b205..a8e09757aa 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -36,10 +36,10 @@ class TaskNode : public Node { std::shared_ptr GetProducedRegst(const std::string& name); const std::list>& GetConsumedRegst(const std::string& name); std::shared_ptr GetSoleConsumedRegst(const std::string& name); - const HashMap>& produced_regsts() { + const HashMap>& produced_regsts() const { return produced_regsts_; } - const HashMap>>& consumed_regsts() { + const HashMap>>& consumed_regsts() const { return consumed_regsts_; } DeviceType device_type() const; @@ -126,6 +126,7 @@ class TaskEdge final : public Edge { std::shared_ptr GetRegst(const std::string& name_in_producer) const; std::shared_ptr GetSoleRegst() const; + std::vector> GetRegsts() const; void AddRegst(const std::string& name_in_producer, std::shared_ptr regst); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 6632fea6d6..6d775d12a9 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -48,6 +48,47 @@ Plan Compiler::Compile() { return plan; } +void Compiler::GenNetTopo(Plan* plan) { + HashMap rid2mid; + HashMap tid2mid; + std::map> net_topo; + + for (const TaskProto& task_proto : plan->task()) { + for (const auto& regst_desc_it : task_proto.produced_regst_desc()) { + rid2mid.emplace(regst_desc_it.second.regst_desc_id(), task_proto.machine_id()); + } + CHECK(tid2mid.emplace(task_proto.task_id(), task_proto.machine_id()).second); + } + + for (const TaskProto& task_proto : plan->task()) { + for (const auto& regst_desc_it : task_proto.produced_regst_desc()) { + int64_t rid = regst_desc_it.second.regst_desc_id(); + auto rid2mid_it = rid2mid.find(rid); + CHECK(rid2mid_it != rid2mid.end()); + int64_t producer_mid = rid2mid_it->second; + for (int64_t consumer_task_id : regst_desc_it.second.consumer_task_id()) { + auto tid2mid_it = tid2mid.find(consumer_task_id); + CHECK(tid2mid_it != tid2mid.end()); + int64_t consumer_mid = tid2mid_it->second; + net_topo[producer_mid].insert(consumer_mid); + net_topo[consumer_mid].insert(producer_mid); + } + } + } + + HashMap std_net_topo; + NetTopo& pb_net_topo = *(plan->mutable_net_topo()); + for (auto& pair : net_topo) { + int64_t src_mid = pair.first; + if (pair.second.count(src_mid)) { pair.second.erase(src_mid); } + std::vector peer_mids(pair.second.begin(), pair.second.end()); + MachineIds pb_mids; + *(pb_mids.mutable_machine_id()) = StdVec2PbRf(peer_mids); + CHECK(std_net_topo.emplace(src_mid, pb_mids).second); + } + *(pb_net_topo.mutable_peer_machine_ids()) = HashMap2PbMap(std_net_topo); +} + Plan Compiler::DoCompile() { #ifdef WITH_CUDA Global::New(); @@ -78,6 +119,7 @@ Plan Compiler::DoCompile() { task_node->ToProto(plan.mutable_task()->Add()); }); plan.set_total_mbn_num(total_mbn_num); + GenNetTopo(&plan); ToDotFile(plan, JoinPath(LogDir(), "/dot/plan.dot")); #ifdef WITH_CUDA Global::Delete(); diff --git a/oneflow/core/job/compiler.h b/oneflow/core/job/compiler.h index 7c5d1691f5..8b8fb51322 100644 --- a/oneflow/core/job/compiler.h +++ b/oneflow/core/job/compiler.h @@ -18,6 +18,7 @@ class Compiler final { private: Plan DoCompile(); + void GenNetTopo(Plan* plan); }; } // namespace oneflow diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 3966f6510d..87324361d5 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -56,6 +56,7 @@ message OtherConf { optional uint64 rdma_mem_block_mbyte = 110 [default = 8]; optional uint64 rdma_recv_msg_buf_mbyte = 111 [default = 6]; + optional int64 reduce_group_size = 112 [default = 20]; optional bool collect_act_event = 113 [default = false]; optional bool enable_mem_sharing = 114 [default = true]; optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false]; diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h index 9ff7199089..6f6a69c19f 100644 --- a/oneflow/core/job/job_desc.h +++ b/oneflow/core/job/job_desc.h @@ -51,6 +51,7 @@ class JobDesc final { return IsTrain() && job_conf_.other().enable_write_snapshot(); } bool enable_blob_mem_sharing() const { return job_conf_.other().enable_blob_mem_sharing(); } + int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); } // machine_name <-> machine_id int64_t MachineID4MachineName(const std::string& machine_name) const; diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index a17a6ebc11..b00ae6f206 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -81,6 +81,8 @@ std::string cluster_thrd_ids_key(const std::string& plan_name) { return plan_name + "_cluster_thrd_ids"; } +std::string net_topo_key(const std::string& plan_name) { return plan_name + "_net_topo"; } + std::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) { return plan_name + "_" + std::to_string(machine_id) + "_" + std::to_string(thrd_id); } @@ -114,6 +116,8 @@ void PushPlan(const std::string& plan_name, const Plan& plan) { } Global::Get()->PushKV(total_mbn_num_key(plan_name), std::to_string(plan.total_mbn_num())); + + Global::Get()->PushKV(net_topo_key(plan_name), plan.net_topo()); } void PullPlan(const std::string& plan_name, Plan* plan) { @@ -122,18 +126,21 @@ void PullPlan(const std::string& plan_name, Plan* plan) { PrintProtoToTextFile(cluster_thrd_ids, JoinPath(LogDir(), cluster_thrd_ids_key(plan_name))); HashMap machine_id2thrd_ids; machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids()); - for (const auto& pair : machine_id2thrd_ids) { - int64_t machine_id = pair.first; - std::vector thrd_id_vec = PbRf2StdVec(pair.second.thrd_id()); - for (auto thrd_id : thrd_id_vec) { - SubPlan sub_plan; - Global::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan); - plan->mutable_task()->MergeFrom(sub_plan.task()); - } + int64_t machine_id = Global::Get()->this_machine_id(); + auto thrd_ids_it = machine_id2thrd_ids.find(machine_id); + CHECK(thrd_ids_it != machine_id2thrd_ids.end()); + std::vector thrd_id_vec = PbRf2StdVec(thrd_ids_it->second.thrd_id()); + for (auto thrd_id : thrd_id_vec) { + SubPlan sub_plan; + Global::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan); + plan->mutable_task()->MergeFrom(sub_plan.task()); } + NetTopo net_topo; std::string total_mbn_num; Global::Get()->PullKV(total_mbn_num_key(plan_name), &total_mbn_num); plan->set_total_mbn_num(oneflow_cast(total_mbn_num)); + Global::Get()->PullKV(net_topo_key(plan_name), &net_topo); + *(plan->mutable_net_topo()) = net_topo; } bool HasRelayPlacement() { @@ -194,9 +201,12 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m Plan improved_plan; PushAvailableMemDescOfThisMachine(); AvailableMemDesc amd; + double start = GetCurTime(); if (machine_ctx->IsThisMachineMaster()) { + double start = GetCurTime(); naive_plan = Compiler().Compile(); + LOG(INFO) << "compile time: " << GetCurTime() - start; amd = PullAvailableMemDesc(); mem_shared_plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan); PushPlan("naive_plan", naive_plan); @@ -208,6 +218,7 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m OF_BARRIER(); PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan")); PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan")); + LOG(INFO) << "push_pull_plan:" << GetCurTime() - start; if (HasRelayPlacement()) { // Experiment Runtime { Runtime experiment_run(mem_shared_plan, true); } diff --git a/oneflow/core/job/plan.proto b/oneflow/core/job/plan.proto index 3b4bf4d435..3480b5ca2d 100644 --- a/oneflow/core/job/plan.proto +++ b/oneflow/core/job/plan.proto @@ -3,7 +3,16 @@ package oneflow; import "oneflow/core/job/task.proto"; +message MachineIds { + repeated int64 machine_id = 1; +} + +message NetTopo { + map peer_machine_ids = 1; +} + message Plan { repeated TaskProto task = 1; required int64 total_mbn_num = 2; + required NetTopo net_topo = 3; } diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index b5e30fd5c2..43f841cb43 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -27,9 +27,11 @@ enum TaskType { kReduceGlobalAdd = 18; kReduceGather = 19; kReduceScatter = 20; - kAccuracy = 21; - kAccuracyAcc = 22; - kAccuracyPrint = 23; + kReduceConcat = 21; + kReduceSplit = 22; + kAccuracy = 23; + kAccuracyAcc = 24; + kAccuracyPrint = 25; }; enum AreaType { diff --git a/oneflow/core/kernel/kernel.proto b/oneflow/core/kernel/kernel.proto index 43b180a5ff..ffa374e85d 100644 --- a/oneflow/core/kernel/kernel.proto +++ b/oneflow/core/kernel/kernel.proto @@ -104,6 +104,14 @@ message ReduceGatherKernelConf { repeated int64 data_offset = 1; } +message ReduceConcatKernelConf { + repeated int64 data_offset = 1; +} + +message ReduceSplitKernelConf { + repeated int64 data_offset = 1; +} + message AccuracyKernelConf { required DataType prediction_type = 1; required DataType label_type = 2; @@ -154,6 +162,8 @@ message KernelConf { NormalizationKernelConf normalization_conf = 250; LocalResponseNormalizationKernelConf local_response_normalization_conf = 300; ReduceGatherKernelConf reduce_gather_conf = 350; + ReduceConcatKernelConf reduce_concat_conf = 351; + ReduceSplitKernelConf reduce_split_conf = 352; AccuracyKernelConf accuracy_conf = 401; } } diff --git a/oneflow/core/kernel/reduce_concat_kernel.cpp b/oneflow/core/kernel/reduce_concat_kernel.cpp new file mode 100644 index 0000000000..f2884fe559 --- /dev/null +++ b/oneflow/core/kernel/reduce_concat_kernel.cpp @@ -0,0 +1,22 @@ +#include "oneflow/core/kernel/reduce_concat_kernel.h" + +namespace oneflow { + +template +void ReduceConcatKernel::ForwardDataContent( + const KernelCtx& ctx, std::function BnInOp2Blob) const { + const auto* other_val = static_cast*>(ctx.other); + int64_t in_bn_id = other_val->first; + bool is_inplace = other_val->second; + if (is_inplace) { return; } + Blob* out_blob = BnInOp2Blob("out"); + char* dst_cur_dptr = out_blob->mut_dptr(); + dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id); + Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id)); + size_t in_byte_size = in_blob->ByteSizeOfDataContentField(); + Memcpy(ctx.device_ctx, dst_cur_dptr, in_blob->dptr(), in_byte_size); +} + +ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceConcatConf, ReduceConcatKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/reduce_concat_kernel.h b/oneflow/core/kernel/reduce_concat_kernel.h new file mode 100644 index 0000000000..6c0b3d796c --- /dev/null +++ b/oneflow/core/kernel/reduce_concat_kernel.h @@ -0,0 +1,22 @@ +#ifndef ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template +class ReduceConcatKernel final : public KernelIf { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceConcatKernel); + ReduceConcatKernel() = default; + ~ReduceConcatKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_ diff --git a/oneflow/core/kernel/reduce_scatter_kernel.cpp b/oneflow/core/kernel/reduce_scatter_kernel.cpp index 2fdded83b0..38f14c97bb 100644 --- a/oneflow/core/kernel/reduce_scatter_kernel.cpp +++ b/oneflow/core/kernel/reduce_scatter_kernel.cpp @@ -5,14 +5,14 @@ namespace oneflow { template void ReduceScatterKernel::ForwardDataContent( const KernelCtx& ctx, std::function BnInOp2Blob) const { - if (device_type == DeviceType::kGPU) { return; } + bool is_inplace = *static_cast(ctx.other); + if (is_inplace) { return; } const Blob* in_blob = BnInOp2Blob("in"); const char* src_cur_dptr = in_blob->dptr(); for (const std::string& obn : this->op_attribute().output_bns()) { Blob* out_blob = BnInOp2Blob(obn); size_t out_byte_size = out_blob->ByteSizeOfDataContentField(); - Memcpy(ctx.device_ctx, out_blob->mut_dptr(), src_cur_dptr, - out_byte_size); + Memcpy(ctx.device_ctx, out_blob->mut_dptr(), src_cur_dptr, out_byte_size); src_cur_dptr += out_byte_size; } } diff --git a/oneflow/core/kernel/reduce_split_kernel.cpp b/oneflow/core/kernel/reduce_split_kernel.cpp new file mode 100644 index 0000000000..847c7b2ead --- /dev/null +++ b/oneflow/core/kernel/reduce_split_kernel.cpp @@ -0,0 +1,22 @@ +#include "oneflow/core/kernel/reduce_split_kernel.h" + +namespace oneflow { + +template +void ReduceSplitKernel::ForwardDataContent( + const KernelCtx& ctx, std::function BnInOp2Blob) const { + bool is_inplace = *static_cast(ctx.other); + if (is_inplace) { return; } + const Blob* in_blob = BnInOp2Blob("in"); + const char* src_cur_dptr = in_blob->dptr(); + for (const std::string& obn : this->op_attribute().output_bns()) { + Blob* out_blob = BnInOp2Blob(obn); + size_t out_byte_size = out_blob->ByteSizeOfDataContentField(); + Memcpy(ctx.device_ctx, out_blob->mut_dptr(), src_cur_dptr, out_byte_size); + src_cur_dptr += out_byte_size; + } +} + +ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceSplitConf, ReduceSplitKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/reduce_split_kernel.h b/oneflow/core/kernel/reduce_split_kernel.h new file mode 100644 index 0000000000..cb1332ad09 --- /dev/null +++ b/oneflow/core/kernel/reduce_split_kernel.h @@ -0,0 +1,22 @@ +#ifndef ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template +class ReduceSplitKernel final : public KernelIf { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceSplitKernel); + ReduceSplitKernel() = default; + ~ReduceSplitKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_ diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index c65c98056d..44cf836ede 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -577,6 +577,14 @@ message TransposeOpConf { repeated int32 perm = 3; } +message ReduceConcatOpConf { + required int32 in_num = 1; +} + +message ReduceSplitOpConf { + required int32 out_num = 1; +} + message ReduceScatterOpConf { required int32 out_num = 1; }; @@ -652,9 +660,11 @@ message OperatorConf { ReduceLocalAddOpConf reduce_local_add_conf = 401; ReduceGlobalAddOpConf reduce_global_add_conf = 402; ReduceGatherOpConf reduce_gather_conf = 403; - RecordLoadOpConf record_load_conf = 404; - AccuracyOpConf accuracy_conf=405; - AccuracyPrintOpConf accuracy_print_conf = 406; + ReduceConcatOpConf reduce_concat_conf = 404; + ReduceSplitOpConf reduce_split_conf = 405; + RecordLoadOpConf record_load_conf = 406; + AccuracyOpConf accuracy_conf=407; + AccuracyPrintOpConf accuracy_print_conf = 408; } } diff --git a/oneflow/core/operator/reduce_concat_op.cpp b/oneflow/core/operator/reduce_concat_op.cpp new file mode 100644 index 0000000000..09861d11d8 --- /dev/null +++ b/oneflow/core/operator/reduce_concat_op.cpp @@ -0,0 +1,54 @@ +#include "oneflow/core/operator/reduce_concat_op.h" +#include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/register/runtime_blob_desc.h" + +namespace oneflow { + +void ReduceConcatOp::InitFromOpConf() { + CHECK(op_conf().has_reduce_concat_conf()); + for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { + EnrollInputBn("in_" + std::to_string(i), false); + } + EnrollOutputBn("out", false); +} + +const PbMessage& ReduceConcatOp::GetCustomizedConf() const { + return op_conf().reduce_concat_conf(); +} + +void ReduceConcatOp::InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + int32_t in_num = op_conf().reduce_concat_conf().in_num(); + CHECK_GE(in_num, 2); + BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0)); + BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn()); + *out_blob = *first_in_blob; + int64_t out_blob_elem_cnt = first_in_blob->shape().elem_cnt(); + for (int32_t i = 1; i < in_num; ++i) { + out_blob_elem_cnt += GetBlobDesc4BnInOp(input_bns().Get(i))->shape().elem_cnt(); + } + out_blob->mut_shape() = Shape({out_blob_elem_cnt}); +} + +void ReduceConcatOp::VirtualGenKernelConf( + std::function GetBlobDesc4BnInOp, const ParallelContext*, + KernelConf* kernel_conf) const { + ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf(); + int64_t offset = 0; + for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { + reduce_concat_conf->mutable_data_offset()->Add(offset); + offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody(); + } + CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody()); +} + +LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const { + LogicalBlobId ret; + ret.set_op_name(op_name()); + ret.set_blob_name("out"); + return ret; +} + +REGISTER_OP(OperatorConf::kReduceConcatConf, ReduceConcatOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/reduce_concat_op.h b/oneflow/core/operator/reduce_concat_op.h new file mode 100644 index 0000000000..c2419c5178 --- /dev/null +++ b/oneflow/core/operator/reduce_concat_op.h @@ -0,0 +1,29 @@ +#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ReduceConcatOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceConcatOp); + ReduceConcatOp() = default; + ~ReduceConcatOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + + void InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const override; + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_ diff --git a/oneflow/core/operator/reduce_split_op.cpp b/oneflow/core/operator/reduce_split_op.cpp new file mode 100644 index 0000000000..78995a57cc --- /dev/null +++ b/oneflow/core/operator/reduce_split_op.cpp @@ -0,0 +1,31 @@ +#include "oneflow/core/operator/reduce_split_op.h" +#include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/register/runtime_blob_desc.h" + +namespace oneflow { + +void ReduceSplitOp::InitFromOpConf() { + CHECK(op_conf().has_reduce_split_conf()); + for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) { + EnrollOutputBn("out_" + std::to_string(i), false); + } + EnrollInputBn("in", false); +} + +const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); } + +void ReduceSplitOp::VirtualGenKernelConf( + std::function GetBlobDesc4BnInOp, const ParallelContext*, + KernelConf* kernel_conf) const { + ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf(); + int64_t offset = 0; + for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) { + reduce_split_conf->mutable_data_offset()->Add(offset); + offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody(); + } + CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody()); +} + +REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/reduce_split_op.h b/oneflow/core/operator/reduce_split_op.h new file mode 100644 index 0000000000..a6f632fe6d --- /dev/null +++ b/oneflow/core/operator/reduce_split_op.h @@ -0,0 +1,29 @@ +#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ReduceSplitOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceSplitOp); + ReduceSplitOp() = default; + ~ReduceSplitOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + + void InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override {} + + private: + void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const override; + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } + LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_ -- GitLab