提交 04a5794b 编写于 作者: S strickland12 提交者: Jinhui Yuan

refine chain_graph constructor (#1131)

* refine task_graph constructor

* use const qualifier

* add ordered_chain_nodes_


Former-commit-id: 945139d9
上级 444c951c
......@@ -136,64 +136,24 @@ std::string ChainNode::VisualStr() const {
}
ChainGraph::ChainGraph(const TaskGraph& task_gph) : task_gph_(task_gph) {
std::vector<TaskNode*> ordered_task_nodes;
HashMap<int64_t, std::vector<TaskNode*>> machine2tasks;
std::vector<std::vector<TaskNode*>> chains;
task_gph.AcyclicTopoForEachNode([&](TaskNode* node) { ordered_task_nodes.emplace_back(node); });
GroupTaskNodesByMachine(ordered_task_nodes, &machine2tasks);
GroupTaskNodesByMachine(task_gph, &machine2tasks);
MergeTaskNodes(machine2tasks, &chains);
for (auto& task_nodes_in_chain : chains) {
ChainNode* chain_node = new ChainNode(task_nodes_in_chain);
for (auto& task_node : task_nodes_in_chain) {
CHECK(task_node2chain_node_.emplace(task_node, chain_node).second);
}
AddAllocatedNode(chain_node);
}
for (auto& cur_task_node : ordered_task_nodes) {
auto cur_chain_node = ChainNode4TaskNode(cur_task_node);
for (auto& task_in_edge : cur_task_node->in_edges()) {
auto src_task_node = task_in_edge->src_node();
if (!cur_task_node->ancestors().count(src_task_node)) {
continue; // ignore kMdUpdt-{kNormalForward, kNormalBackward} edge
}
auto src_chain_node = ChainNode4TaskNode(src_task_node);
if (cur_chain_node == src_chain_node) continue;
if (HasChainEdge(src_chain_node, cur_chain_node)) continue;
Connect(src_chain_node, NewEdge(), cur_chain_node);
}
}
TopoForEachNode([&](ChainNode* chain_node) {
ordered_chain_nodes_.emplace_back(chain_node);
int64_t stream_id = chain_node->task_nodes().front()->GlobalWorkStreamId();
int64_t chain_id = Global<IDMgr>::Get()->AllocateChainId(stream_id);
chain_node->set_chain_id(chain_id);
for (auto& task_node : chain_node->task_nodes()) {
ordered_task_nodes_.emplace_back(task_node);
}
});
InitChainNode(chains);
InitChainEdge(chains);
SetChainId4ChainNode();
ToDotWithAutoFilePath();
}
void ChainGraph::GroupTaskNodesByMachine(const std::vector<TaskNode*>& ordered_task_nodes,
HashMap<int64_t, std::vector<TaskNode*>>* machine2tasks) {
for (auto& task_node : ordered_task_nodes) {
int64_t machine_id = task_node->machine_id();
auto machine_it = machine2tasks->find(machine_id);
if (machine_it != machine2tasks->end()) {
machine_it->second.push_back(task_node);
} else {
std::vector<TaskNode*> task_nodes{task_node};
CHECK(machine2tasks->emplace(machine_id, task_nodes).second);
}
}
void ChainGraph::GroupTaskNodesByMachine(
const TaskGraph& task_gph, HashMap<int64_t, std::vector<TaskNode*>>* machine2tasks) const {
task_gph.AcyclicTopoForEachNode(
[&](TaskNode* node) { (*machine2tasks)[node->machine_id()].emplace_back(node); });
}
void ChainGraph::MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>& machine2tasks,
std::vector<std::vector<TaskNode*>>* chains) {
std::vector<std::vector<TaskNode*>>* chains) const {
int64_t machine_num = machine2tasks.size();
int64_t cpu_num = std::thread::hardware_concurrency();
int64_t thread_pool_size = std::min(machine_num, cpu_num);
......@@ -214,21 +174,46 @@ void ChainGraph::MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>&
counter.WaitUntilCntEqualZero();
}
ChainNode* ChainGraph::ChainNode4TaskNode(TaskNode* task_node) const {
auto task2chain_it = task_node2chain_node_.find(task_node);
CHECK(task2chain_it != task_node2chain_node_.end());
return task2chain_it->second;
// Creates one ChainNode per entry of `chains` and registers it with the graph.
//
// Every task node of a chain is recorded in task_node2chain_node_ as belonging
// to the freshly created chain node; the CHECK fires if a task node was already
// recorded, i.e. if it shows up more than once across `chains`.
void ChainGraph::InitChainNode(const std::vector<std::vector<TaskNode*>>& chains) {
  for (const std::vector<TaskNode*>& task_nodes_in_chain : chains) {
    ChainNode* const chain_node = new ChainNode(task_nodes_in_chain);
    for (TaskNode* const& task_node : task_nodes_in_chain) {
      CHECK(task_node2chain_node_.emplace(task_node, chain_node).second);
    }
    // Hand the node to the graph only after all of its task nodes are indexed.
    AddAllocatedNode(chain_node);
  }
}
// Connects chain nodes according to the edges between their task nodes.
//
// For every incoming edge of every task node listed in `chains`:
//   - the edge is ignored when IsBackEdge(src, dst) reports true;
//   - the edge is ignored when both endpoints map to the same chain node;
//   - the edge is ignored when the two chain nodes are already connected;
//   - otherwise a new chain edge is created from the source chain node to the
//     current chain node.
void ChainGraph::InitChainEdge(const std::vector<std::vector<TaskNode*>>& chains) {
  for (const std::vector<TaskNode*>& task_nodes_in_chain : chains) {
    for (TaskNode* dst_task_node : task_nodes_in_chain) {
      ChainNode* const dst_chain_node = ChainNode4TaskNode(dst_task_node);
      for (const auto& in_edge : dst_task_node->in_edges()) {
        auto producer_task_node = in_edge->src_node();
        if (IsBackEdge(producer_task_node, dst_task_node)) { continue; }
        ChainNode* const producer_chain_node = ChainNode4TaskNode(producer_task_node);
        // Intra-chain edges and already-present chain edges add nothing new.
        if (producer_chain_node == dst_chain_node
            || HasChainEdge(producer_chain_node, dst_chain_node)) {
          continue;
        }
        Connect(producer_chain_node, NewEdge(), dst_chain_node);
      }
    }
  }
}
void ChainGraph::SetChainId4ChainNode() {
TopoForEachNode([&](ChainNode* chain_node) {
ordered_chain_nodes_.emplace_back(chain_node);
int64_t stream_id = chain_node->TaskNodes().front()->GlobalWorkStreamId();
int64_t chain_id = Global<IDMgr>::Get()->AllocateChainId(stream_id);
chain_node->SetChainId(chain_id);
});
}
bool ChainGraph::HasChainEdge(ChainNode* src, ChainNode* dst) const {
bool has_chain_edge = false;
for (auto& out_edge : src->out_edges()) {
if (out_edge->dst_node() == dst) {
has_chain_edge = true;
break;
}
if (out_edge->dst_node() == dst) { return true; }
}
return has_chain_edge;
return false;
}
} // namespace oneflow
......@@ -33,12 +33,12 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
virtual ~ChainNode() = default;
std::string VisualStr() const override;
const std::vector<TaskNode*>& task_nodes() const { return task_nodes_; }
const std::vector<TaskNode*>& TaskNodes() const { return task_nodes_; }
int64_t chain_id() const {
CHECK_NE(chain_id_, -1);
return chain_id_;
}
void set_chain_id(int64_t val) { chain_id_ = val; }
void SetChainId(int64_t val) { chain_id_ = val; }
private:
std::vector<TaskNode*> task_nodes_;
......@@ -64,21 +64,25 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
ChainGraph(const TaskGraph& task_gph);
const char* TypeName() const override { return "ChainGraph"; }
const std::vector<ChainNode*>& ordered_chain_nodes() const { return ordered_chain_nodes_; }
const std::vector<TaskNode*>& ordered_task_nodes() const { return ordered_task_nodes_; }
const std::vector<ChainNode*>& OrderdedChainNodes() const { return ordered_chain_nodes_; }
private:
ChainNode* ChainNode4TaskNode(TaskNode* task_node) const;
bool HasChainEdge(ChainNode* src, ChainNode* dst) const;
void GroupTaskNodesByMachine(const std::vector<TaskNode*>& ordered_task_nodes,
HashMap<int64_t, std::vector<TaskNode*>>* machine2tasks);
ChainNode* ChainNode4TaskNode(TaskNode* task_node) const {
return task_node2chain_node_.at(task_node);
}
void GroupTaskNodesByMachine(const TaskGraph& task_gph,
HashMap<int64_t, std::vector<TaskNode*>>* machine2tasks) const;
void MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>& machine2tasks,
std::vector<std::vector<TaskNode*>>* chains);
std::vector<std::vector<TaskNode*>>* chains) const;
void InitChainNode(const std::vector<std::vector<TaskNode*>>& chains);
void InitChainEdge(const std::vector<std::vector<TaskNode*>>& chains);
void SetChainId4ChainNode();
const TaskGraph& task_gph_;
HashMap<TaskNode*, ChainNode*> task_node2chain_node_;
std::vector<ChainNode*> ordered_chain_nodes_;
std::vector<TaskNode*> ordered_task_nodes_;
};
} // namespace oneflow
......
......@@ -76,10 +76,10 @@ void TaskGraph::FindChainsInSameStream() {
CollectAncestorsForEachNode();
ChainGraph chain_gph(*this);
const auto& ordered_chain_nodes = chain_gph.ordered_chain_nodes();
const auto& ordered_chain_nodes = chain_gph.OrderdedChainNodes();
int64_t order_in_graph = 0;
for (auto& chain_node : ordered_chain_nodes) {
auto& ordered_in_chain = chain_node->task_nodes();
auto& ordered_in_chain = chain_node->TaskNodes();
int64_t chain_id = chain_node->chain_id();
for (auto& task_node : ordered_in_chain) {
task_node->set_chain_id(chain_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册