Commit 6d9663c7 authored by Niu Chong, committed by Jinhui Yuan

feat: refactor reduce struct to control network order, avoiding network contention (#994)

* feat: avoid net contention by adding ctrl edge in ReduceStruct

* refine(task_graph.h/cpp): refine AddCtrlEdgeInReduceStruct()

* fix(graph/task_graph.cpp): fix the bug of machine order

* fix(graph/task_graph.cpp): do not add ctrl edge with reduce scatter

* feat: add ReduceGlobalAddCompActor

* fix: fix the bug of reduce_global_actor/kernel

* chore: remove unused vim .swp file

* fix(graph/task_graph.cpp): fix the bug of sorting CopyCommNet nodes when building reduce ctrl edges

* fix(graph/task_graph.h/cpp): add CtrlEdge for ReduceGather

* revert: remove the ReduceGlobalAddCompActor from this PR

* feat: add use_ordered_allreduce_in_mdupdt in OtherConf
Parent ea965893
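The ordering rule at the heart of this change can be illustrated standalone. The sketch below is not part of the diff; it is plain C++ that reuses only the sort formula from the task_graph.cpp hunk further down. It prints the receive order each machine uses: every machine starts with its successor in the ring, so at any given rank the n machines all pull from distinct peers instead of piling onto one sender.

#include <cstdint>
#include <cstdio>

// Standalone sketch of the ordering added below: on machine m, the transfer
// from peer i gets rank (i + n - m - 1) % n, i.e. peers are ranked by their
// ring distance from m's successor.
int main() {
  const int64_t n = 4;  // stand-in for total_machine_num
  for (int64_t m = 0; m < n; ++m) {
    std::printf("machine %d receives from:", (int)m);
    for (int64_t rank = 0; rank < n - 1; ++rank) {
      for (int64_t i = 0; i < n; ++i) {
        if (i != m && (i + n - m - 1) % n == rank) { std::printf(" %d", (int)i); }
      }
    }
    std::printf("\n");
  }
  return 0;
}

Reading the output column-wise, each rank pairs every machine with a distinct peer, which is the "avoiding network contention" claim in the commit title.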
@@ -51,10 +51,10 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
   return conf;
 }
-void CopyCommNetTaskNode::Init(int64_t machine_id, int64_t peer_machine_id) {
+void CopyCommNetTaskNode::Init(int64_t machine_id, int64_t src_machine_id) {
   set_machine_id(machine_id);
   set_thrd_id(Global<IDMgr>::Get()->CommNetThrdId());
-  peer_machine_id_ = peer_machine_id;
+  peer_machine_id_ = src_machine_id;
 }
 namespace {
......
@@ -59,6 +59,7 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
   void Init(int64_t machine_id, int64_t src_machine_id);
   int64_t AllocateLocalWorkStreamId() override;
+  int64_t peer_machine_id() const { return peer_machine_id_; }
  private:
   void InitProducedRegstMemCase(MemoryCase*) override;
......
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/reduce_global_add_compute_task_node.h"
#include "oneflow/core/graph/reduce_gather_compute_task_node.h"
namespace oneflow {
@@ -89,6 +90,98 @@ void TaskGraph::AddOrderingCtrlEdgeInSameChain() {
}
}
void TaskGraph::AddCtrlEdgeInReduceStruct() {
  int64_t total_machine_num = Global<JobDesc>::Get()->resource().machine().size();
  if (total_machine_num == 1) { return; }
  AddCtrlEdgeForReduceTaskNode<ReduceGlobalAddLogicalNode, ReduceGlobalAddCompTaskNode>(
      total_machine_num);
  AddCtrlEdgeForReduceTaskNode<ReduceGatherLogicalNode, ReduceGatherCompTaskNode>(
      total_machine_num);
}
template<typename LogicalNodeType, typename TaskNodeType>
void TaskGraph::AddCtrlEdgeForReduceTaskNode(int64_t total_machine_num) {
  // Group the reduce task nodes by their logical node, then by machine id.
  HashMap<const LogicalNodeType*, HashMap<int64_t, std::vector<TaskNodeType*>>>
      machine_id2reduce_task_nodes4same_logical_node;
  ForEachNode([&](TaskNode* task_node) {
    TaskNodeType* reduce_task_node = dynamic_cast<TaskNodeType*>(task_node);
    if (reduce_task_node != nullptr) {
      const LogicalNodeType* logical_node =
          dynamic_cast<const LogicalNodeType*>(reduce_task_node->logical_node());
      CHECK(logical_node != nullptr);
      machine_id2reduce_task_nodes4same_logical_node[logical_node][reduce_task_node->machine_id()]
          .push_back(reduce_task_node);
    }
  });
  for (const auto& kv : machine_id2reduce_task_nodes4same_logical_node) {
    const auto& machine_id2reduce_task_nodes = kv.second;
    if (machine_id2reduce_task_nodes.size() == 1) { continue; }
    for (int64_t machine_id = 0; machine_id < machine_id2reduce_task_nodes.size(); ++machine_id) {
      std::vector<std::pair<CopyCommNetTaskNode*, int64_t>> commnet_nodes_with_sort_val;
      CollectCopyCommNetForReduceTaskNodes(machine_id2reduce_task_nodes.at(machine_id),
                                           &commnet_nodes_with_sort_val);
      // Peer machine i gets rank (i - machine_id - 1) mod n, so the receive
      // order on each machine starts from its successor in the ring and the
      // transfers are staggered across machines.
      std::vector<int64_t> machine_id2sort_order(total_machine_num);
      for (size_t i = 0; i < total_machine_num; ++i) {
        machine_id2sort_order.at(i) = (i + total_machine_num - machine_id - 1) % total_machine_num;
      }
      std::sort(commnet_nodes_with_sort_val.begin(), commnet_nodes_with_sort_val.end(),
                [&](const std::pair<CopyCommNetTaskNode*, int64_t>& lhs,
                    const std::pair<CopyCommNetTaskNode*, int64_t>& rhs) {
                  if (lhs.first->peer_machine_id() == rhs.first->peer_machine_id()) {
                    return lhs.second < rhs.second;
                  }
                  return machine_id2sort_order.at(lhs.first->peer_machine_id())
                         < machine_id2sort_order.at(rhs.first->peer_machine_id());
                });
      // Chain the sorted CopyCommNet nodes with ctrl regsts so the transfers
      // run in that order. (The i + 1 < size() form also guards the empty
      // case, where size() - 1 would underflow.)
      for (size_t i = 0; i + 1 < commnet_nodes_with_sort_val.size(); ++i) {
        commnet_nodes_with_sort_val.at(i).first->BuildCtrlRegstDescIfNeed(
            commnet_nodes_with_sort_val.at(i + 1).first);
      }
    }
  }
}
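As a concrete check of the comparator above, this hedged stand-alone sketch replays it on plain (peer_machine_id, sort_val) pairs, standing in for the CopyCommNet nodes collected on machine 1 of a 3-machine job; the values are illustrative only.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Hedged stand-in for the sort above: (peer_machine_id, sort_val) pairs
// mimic the CopyCommNet nodes collected on machine m = 1 of n = 3 machines.
int main() {
  const int64_t n = 3;
  const int64_t m = 1;
  std::vector<std::pair<int64_t, int64_t>> nodes = {{0, 0}, {0, 1}, {2, 0}, {2, 1}};
  std::vector<int64_t> order(n);
  for (int64_t i = 0; i < n; ++i) { order.at(i) = (i + n - m - 1) % n; }
  std::sort(nodes.begin(), nodes.end(),
            [&](const std::pair<int64_t, int64_t>& lhs, const std::pair<int64_t, int64_t>& rhs) {
              if (lhs.first == rhs.first) { return lhs.second < rhs.second; }
              return order.at(lhs.first) < order.at(rhs.first);
            });
  // The ctrl edges then chain the transfers in exactly this printed order.
  for (const auto& node : nodes) {
    std::printf("recv from machine %d (sort_val %d)\n", (int)node.first, (int)node.second);
  }
  return 0;
}

It prints machine 2's two transfers before machine 0's, because machine 2 is machine 1's successor in the ring.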
template<typename TaskNodeType>
void TaskGraph::CollectCopyCommNetForReduceTaskNodes(
    const std::vector<TaskNodeType*>& reduce_task_nodes,
    std::vector<std::pair<CopyCommNetTaskNode*, int64_t>>* commnet_nodes_with_sort_val) {
  HashSet<CopyCommNetTaskNode*> inserted_commnet_nodes;
  for (TaskNodeType* reduce_task_node : reduce_task_nodes) {
    for (TaskEdge* in_edge : reduce_task_node->in_edges()) {
      // Walk each input chain upstream to the previous reduce stage and record
      // the first CopyCommNet node met, keyed by the consumer's parallel id so
      // that transfers from the same peer machine sort deterministically.
      TaskNode* pre_node = in_edge->src_node();
      while (IsEndingTaskType<TaskNodeType>(pre_node->GetTaskType()) == false) {
        if (pre_node->GetTaskType() == TaskType::kCopyCommNet) {
          CopyCommNetTaskNode* commnet_node = dynamic_cast<CopyCommNetTaskNode*>(pre_node);
          CHECK(commnet_node != nullptr);
          if (inserted_commnet_nodes.find(commnet_node) == inserted_commnet_nodes.end()) {
            commnet_nodes_with_sort_val->emplace_back(
                commnet_node, reduce_task_node->parallel_ctx()->parallel_id());
            inserted_commnet_nodes.insert(commnet_node);
          }
          break;
        }
        pre_node = pre_node->SoleInEdge()->src_node();
      }
    }
  }
}
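The walk above relies on every intermediate node on the path having a sole in-edge. A toy model of that loop (hypothetical structs, not the real TaskNode API):

#include <cstdio>

// Toy model of the upstream walk: follow sole producers from a reduce node's
// input until the previous reduce stage, noting the first kCopyCommNet met.
enum class Type { kReduceLocalAdd, kCopyCommNet, kCopyHd };

struct Node {
  Type type;
  Node* producer;  // stand-in for SoleInEdge()->src_node()
};

int main() {
  Node local_add{Type::kReduceLocalAdd, nullptr};
  Node commnet{Type::kCopyCommNet, &local_add};
  Node copy_hd{Type::kCopyHd, &commnet};
  Node* pre = &copy_hd;  // in_edge->src_node() of the consuming reduce node
  while (pre->type != Type::kReduceLocalAdd) {  // IsEndingTaskType stand-in
    if (pre->type == Type::kCopyCommNet) {
      std::printf("collected the CopyCommNet node\n");
      break;
    }
    pre = pre->producer;
  }
  return 0;
}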
// The upstream walk for a global-add node ends at the local-add stage; for a
// gather node it ends at the global-add stage.
template<>
bool TaskGraph::IsEndingTaskType<ReduceGlobalAddCompTaskNode>(TaskType type) {
  return type == TaskType::kReduceLocalAdd;
}
template<>
bool TaskGraph::IsEndingTaskType<ReduceGatherCompTaskNode>(TaskType type) {
  return type == TaskType::kReduceGlobalAdd;
}
void TaskGraph::AddMutexCtrlEdgeInSameChain() { UNIMPLEMENTED(); }
void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() { UNIMPLEMENTED(); }
......
@@ -5,6 +5,7 @@
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/copy_task_node.h"
namespace oneflow {
@@ -18,6 +19,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
  const char* TypeName() const override { return "TaskGraph"; }
  void AddOrderingCtrlEdgeInSameChain();
  void AddCtrlEdgeInReduceStruct();
  void AddMutexCtrlEdgeInSameChain();
  void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
  void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const;
@@ -60,6 +62,17 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
  void CollectAncestorsForEachNode();
  void FindChainsInSameStream();

  template<typename LogicalNodeType, typename TaskNodeType>
  void AddCtrlEdgeForReduceTaskNode(int64_t total_machine_num);
  template<typename TaskNodeType>
  void CollectCopyCommNetForReduceTaskNodes(
      const std::vector<TaskNodeType*>& reduce_task_nodes,
      std::vector<std::pair<CopyCommNetTaskNode*, int64_t>>* commnet_nodes_with_sort_val);
  template<typename TaskNodeType>
  bool IsEndingTaskType(TaskType type);

  std::unique_ptr<const LogicalGraph> logical_gph_;
  std::vector<TaskNode*> ordered_task_nodes_;
};
......
@@ -60,6 +60,9 @@ Plan Compiler::DoCompile() {
  task_gph->ForEachNode(std::bind(&TaskNode::EraseEmptyProducedRegst, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::ClearOutOfDateConsumedRegst, _1));
  task_gph->AddOrderingCtrlEdgeInSameChain();
  if (Global<JobDesc>::Get()->other_conf().use_ordered_allreduce_in_mdupdt()) {
    task_gph->AddCtrlEdgeInReduceStruct();
  }
  Plan plan;
  task_gph->ForEachNode([&](TaskNode* task_node) {
    if (task_node->IsMeaningLess()) { return; }
......
@@ -58,6 +58,7 @@ message OtherConf {
  optional bool collect_act_event = 113 [default = false];
  optional bool enable_mem_sharing = 114 [default = true];
  optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false];

  oneof JobType {
    TrainConf train_conf = 200;
......
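With the new field in place, turning the ordered allreduce on is a one-line change in the job configuration. A hedged prototxt fragment follows; the enclosing "other" field name is an assumption inferred from the other_conf() accessor in the compiler hunk, and only use_ordered_allreduce_in_mdupdt is confirmed by this diff:

# Hypothetical job conf fragment; the "other" field name is an assumption.
other {
  use_ordered_allreduce_in_mdupdt: true
}

Compiler::DoCompile() (previous hunk) checks exactly this flag before calling AddCtrlEdgeInReduceStruct(), so the ordering pass stays opt-in.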