未验证 提交 40c299bc 编写于 作者: J Jinhui Yuan 提交者: GitHub

Remove duplicate ReduceTaskNodes caused by ReduceConcat and ReduceSplit (#1179)

上级 3aa79c70
......@@ -133,26 +133,17 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() {
}
}
// Bundle of the computation task nodes handled together as one reduce
// structure. Every slot starts out as nullptr until it is filled in.
struct ReduceTaskNodes {
  CompTaskNode* concat{nullptr};
  CompTaskNode* scatter{nullptr};
  CompTaskNode* local_add{nullptr};
  CompTaskNode* global_add{nullptr};
  CompTaskNode* gather{nullptr};
  CompTaskNode* split{nullptr};
};
void TaskGraph::EnableMemSharingInReduceStruct() {
HashMap<CompTaskNode*, ReduceTaskNodes> bw2reduce_tasks;
CollectReduceTaskNodes(&bw2reduce_tasks);
for (auto& pair : bw2reduce_tasks) {
EnableMemSharingInOneReduce(pair.second);
AddCtrlEdge4MemSharingInOneReduce(pair.second);
std::unordered_set<ReduceTaskNodes, ReduceTaskNodesHasher> reduce_tasks;
CollectReduceTaskNodes(&reduce_tasks);
for (auto& reduce_task : reduce_tasks) {
EnableMemSharingInOneReduce(reduce_task);
AddCtrlEdge4MemSharingInOneReduce(reduce_task);
}
}
void TaskGraph::CollectReduceTaskNodes(
HashMap<CompTaskNode*, ReduceTaskNodes>* bw2reduce_tasks) const {
std::unordered_set<ReduceTaskNodes, ReduceTaskNodesHasher>* reduce_tasks) const {
auto FindSuccReduceTaskNode = [](CompTaskNode* task_node, TaskType type) -> CompTaskNode* {
for (TaskEdge* out_edge : task_node->out_edges()) {
TaskNode* dst_node = out_edge->dst_node();
......@@ -186,7 +177,7 @@ void TaskGraph::CollectReduceTaskNodes(
return;
}
ReduceTaskNodes& reduce_task_nodes = (*bw2reduce_tasks)[bw_task_node];
ReduceTaskNodes reduce_task_nodes;
CompTaskNode* diff_acc_task_node = FindSuccReduceTaskNode(bw_task_node, TaskType::kMdDiffAcc);
if (diff_acc_task_node != nullptr) {
FindConcatAndScatter(diff_acc_task_node, &reduce_task_nodes);
......@@ -212,6 +203,7 @@ void TaskGraph::CollectReduceTaskNodes(
CHECK(reduce_task_nodes.scatter != nullptr);
CHECK(reduce_task_nodes.global_add != nullptr);
CHECK(reduce_task_nodes.gather != nullptr);
reduce_tasks->insert(reduce_task_nodes);
});
}
......
......@@ -9,7 +9,26 @@
namespace oneflow {
class ReduceTaskNodes;
struct ReduceTaskNodes {
CompTaskNode* concat = nullptr;
CompTaskNode* scatter = nullptr;
CompTaskNode* local_add = nullptr;
CompTaskNode* global_add = nullptr;
CompTaskNode* gather = nullptr;
CompTaskNode* split = nullptr;
bool operator==(const ReduceTaskNodes& rhs) const {
return this->concat == rhs.concat && this->scatter == rhs.scatter
&& this->local_add == rhs.local_add && this->global_add == rhs.global_add
&& this->gather == rhs.gather && this->split == rhs.split;
}
};
struct ReduceTaskNodesHasher {
std::size_t operator()(const ReduceTaskNodes& key) const {
return (size_t)(key.concat) ^ (size_t)(key.scatter) ^ (size_t)(key.local_add)
^ (size_t)(key.global_add) ^ (size_t)(key.gather) ^ (size_t)(key.split);
}
};
class TaskGraph final : public Graph<TaskNode, TaskEdge> {
public:
......@@ -24,7 +43,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void AddOrderingCtrlEdgeInSameChain();
void EnableMemSharingInReduceStruct();
void CollectReduceTaskNodes(HashMap<CompTaskNode*, ReduceTaskNodes>*) const;
void CollectReduceTaskNodes(std::unordered_set<ReduceTaskNodes, ReduceTaskNodesHasher>*) const;
void EnableMemSharingInOneReduce(const ReduceTaskNodes&);
void AddCtrlEdge4MemSharingInOneReduce(const ReduceTaskNodes&);
void BuildCtrlRegstBetweenReduceCopyNodes(const CompTaskNode* src_reduce,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册