未验证 提交 d05a7b3c 编写于 作者: J Jinhui Yuan 提交者: GitHub

Add chain (#947)

* add PathType

* add PathType for LogicalNode

* add PathType for task node

* fix if BldSubTskGphByOneToOne has null logical input

* simplify SetPathType for new task nodes

* add chain graph

* UncyclicTopoForEachNode

* simpify UncyclicTopoForEachNode

* CollectAncestorsForEachTaskNode

* AddOrderCtrlEdgeBetweenCopyAndMdUpdt

* Chain merge

* fix

* fix chain merge algorithm

* use both ancestors and descendants for merging chain

* change Path to Area

* fix chain merge

* chain graph works

* remove descendants of task node

* refine

* fill chain_id and order_in_graph

* refine

* fix iterator error

* Use area id
上级 75291d55
......@@ -27,6 +27,17 @@
DECLARE_string(log_dir);
namespace std {
template<typename T0, typename T1>
struct hash<std::pair<T0, T1>> {
std::size_t operator()(const std::pair<T0, T1>& p) const {
auto h0 = std::hash<T0>{}(p.first);
auto h1 = std::hash<T1>{}(p.second);
return h0 ^ h1;
}
};
} // namespace std
namespace oneflow {
#define OF_DISALLOW_COPY(ClassName) \
......
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/task_node.h"
namespace oneflow {
namespace {
void InitChains(const std::vector<TaskNode*>& ordered_nodes, std::list<Chain>* chain_list,
Task2ChainItMap* task2chain_it) {
chain_list->clear();
task2chain_it->clear();
for (const auto& task_node : ordered_nodes) {
chain_list->emplace_back();
task2chain_it->insert({task_node, --chain_list->end()});
Chain& cur_chain = chain_list->back();
cur_chain.ancestors.clear();
cur_chain.ancestors_and_this.clear();
cur_chain.nodes = {task_node};
cur_chain.area_id = task_node->area_id();
cur_chain.stream_id = task_node->GlobalWorkStreamId();
cur_chain.ancestors_and_this.insert(cur_chain.nodes.begin(), cur_chain.nodes.end());
cur_chain.ancestors.insert(task_node->ancestors().begin(), task_node->ancestors().end());
cur_chain.ancestors_and_this.insert(cur_chain.ancestors.begin(), cur_chain.ancestors.end());
}
}
bool IsConnected(ChainIt src_chain_it, ChainIt dst_chain_it, const Task2ChainItMap& task2chain) {
for (auto& dst_task_node : dst_chain_it->nodes) {
for (auto& in_edge : dst_task_node->in_edges()) {
auto src_task_node = in_edge->src_node();
if (IsBackEdge(src_task_node, dst_task_node)) { continue; }
auto src_chain_it_it = task2chain.find(src_task_node);
if (src_chain_it_it != task2chain.end()) {
if (src_chain_it == src_chain_it_it->second) { return true; }
}
}
}
return false;
}
bool DoMergeWithConnect(std::list<ChainIt>& chains, ChainIt rhs, Task2ChainItMap* task2chain_it) {
for (auto chains_it = chains.rbegin(); chains_it != chains.rend(); ++chains_it) {
ChainIt lhs = *chains_it;
if (IsConnected(lhs, rhs, *task2chain_it) && lhs->ancestors_and_this == rhs->ancestors) {
for (TaskNode* node : rhs->nodes) {
lhs->nodes.push_back(node);
lhs->ancestors_and_this.insert(node);
task2chain_it->at(node) = lhs;
}
return true;
}
}
return false;
}
bool DoMergeWithoutConnect(std::list<ChainIt>& chains, ChainIt rhs,
Task2ChainItMap* task2chain_it) {
for (auto chains_it = chains.rbegin(); chains_it != chains.rend(); ++chains_it) {
ChainIt lhs = *chains_it;
if (!IsConnected(lhs, rhs, *task2chain_it) && lhs->ancestors == rhs->ancestors) {
for (TaskNode* node : rhs->nodes) {
lhs->nodes.push_back(node);
lhs->ancestors_and_this.insert(node);
task2chain_it->at(node) = lhs;
}
return true;
}
}
return false;
}
bool TryMerge(
std::list<Chain>* chain_list, Task2ChainItMap* task2chain_it,
std::function<bool(std::list<ChainIt>& chains, ChainIt cur_it, Task2ChainItMap* task2chain_it)>
DoMerge) {
HashMap<std::pair<int64_t, int64_t>, std::list<ChainIt>> stream_area2chains;
bool merge_happened = false;
for (auto cur_chain_it = chain_list->begin(); cur_chain_it != chain_list->end();) {
std::pair<int64_t, int64_t> stream_area_id = {cur_chain_it->stream_id, cur_chain_it->area_id};
auto stream_area_it = stream_area2chains.find(stream_area_id);
if (stream_area_it != stream_area2chains.end()
&& DoMerge(stream_area_it->second, cur_chain_it, task2chain_it)) {
cur_chain_it = chain_list->erase(cur_chain_it);
merge_happened = true;
} else {
stream_area2chains[stream_area_id].push_back(cur_chain_it);
++cur_chain_it;
}
}
return merge_happened;
}
void MergeChains(std::list<Chain>* chain_list, Task2ChainItMap* task2chain_it) {
while (TryMerge(chain_list, task2chain_it, DoMergeWithConnect)
|| TryMerge(chain_list, task2chain_it, DoMergeWithoutConnect)) {}
}
} // namespace
std::string ChainNode::VisualStr() const {
std::stringstream ss;
ss << "chain_id:" << chain_id_ << "\\n";
for (auto& task_node : chain_it_->nodes) {
ss << TaskType_Name(task_node->GetTaskType()) << ":" << task_node->task_id() << "\\n";
}
return ss.str();
}
ChainGraph::ChainGraph(const TaskGraph& task_gph) : task_gph_(task_gph) {
std::vector<TaskNode*> ordered_task_nodes;
task_gph.AcyclicTopoForEachNode([&](TaskNode* node) { ordered_task_nodes.emplace_back(node); });
InitChains(ordered_task_nodes, &chain_list_, &task_node2chain_it_);
MergeChains(&chain_list_, &task_node2chain_it_);
for (auto chain_it = chain_list_.begin(); chain_it != chain_list_.end(); ++chain_it) {
ChainNode* chain_node = new ChainNode(chain_it);
chain_it->chain_node = chain_node;
AddAllocatedNode(chain_node);
}
for (auto& cur_task_node : ordered_task_nodes) {
auto cur_chain_node = ChainNode4TaskNode(cur_task_node);
for (auto& task_in_edge : cur_task_node->in_edges()) {
auto src_task_node = task_in_edge->src_node();
if (!cur_task_node->ancestors().count(src_task_node)) {
continue; // ignore kMdUpdt-{kNormalForward, kNormalBackward} edge
}
auto src_chain_node = ChainNode4TaskNode(src_task_node);
if (cur_chain_node == src_chain_node) continue;
if (HasChainEdge(src_chain_node, cur_chain_node)) continue;
Connect(src_chain_node, NewEdge(), cur_chain_node);
}
}
TopoForEachNode([&](ChainNode* chain_node) {
ordered_chain_nodes_.emplace_back(chain_node);
int64_t stream_id = chain_node->chain_it()->nodes.front()->GlobalWorkStreamId();
int64_t chain_id = Global<IDMgr>::Get()->AllocateChainId(stream_id);
chain_node->set_chain_id(chain_id);
for (auto& task_node : chain_node->chain_it()->nodes) {
ordered_task_nodes_.emplace_back(task_node);
}
});
ToDotWithAutoFilePath();
}
ChainNode* ChainGraph::ChainNode4TaskNode(TaskNode* task_node) const {
auto task2chain_it = task_node2chain_it_.find(task_node);
CHECK(task2chain_it != task_node2chain_it_.end());
return task2chain_it->second->chain_node;
}
bool ChainGraph::HasChainEdge(ChainNode* src, ChainNode* dst) const {
bool has_chain_edge = false;
for (auto& out_edge : src->out_edges()) {
if (out_edge->dst_node() == dst) {
has_chain_edge = true;
break;
}
}
return has_chain_edge;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#include "oneflow/core/graph/graph.h"
namespace oneflow {
class TaskNode;
class ChainNode;
struct Chain {
// nodes belong to this chain
std::vector<TaskNode*> nodes;
// ancestors of the nodes in this chain
HashSet<TaskNode*> ancestors;
// ancestors_and_this = nodes + ancestors
HashSet<TaskNode*> ancestors_and_this;
int64_t stream_id;
int64_t area_id;
ChainNode* chain_node;
};
using ChainIt = std::list<Chain>::iterator;
using Task2ChainItMap = HashMap<const TaskNode*, ChainIt>;
class ChainEdge;
class ChainNode final : public Node<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainNode);
explicit ChainNode(ChainIt chain_it) : chain_it_(chain_it), chain_id_(-1) {}
virtual ~ChainNode() = default;
std::string VisualStr() const override;
ChainIt chain_it() const { return chain_it_; }
const std::vector<TaskNode*>& ordered_task_nodes() const { return chain_it_->nodes; }
int64_t chain_id() const {
CHECK_NE(chain_id_, -1);
return chain_id_;
}
void set_chain_id(int64_t val) { chain_id_ = val; }
private:
ChainIt chain_it_;
int64_t chain_id_;
};
class ChainEdge final : public Edge<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainEdge);
ChainEdge() = default;
~ChainEdge() = default;
private:
};
class TaskGraph;
class ChainGraph final : public Graph<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainGraph);
ChainGraph() = delete;
~ChainGraph() = default;
ChainGraph(const TaskGraph& task_gph);
const char* TypeName() const override { return "ChainGraph"; }
const std::vector<ChainNode*>& ordered_chain_nodes() const { return ordered_chain_nodes_; }
const std::vector<TaskNode*>& ordered_task_nodes() const { return ordered_task_nodes_; }
private:
ChainNode* ChainNode4TaskNode(TaskNode* task_node) const;
bool HasChainEdge(ChainNode* src, ChainNode* dst) const;
const TaskGraph& task_gph_;
std::list<Chain> chain_list_;
Task2ChainItMap task_node2chain_it_;
std::vector<ChainNode*> ordered_chain_nodes_;
std::vector<TaskNode*> ordered_task_nodes_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
......@@ -16,6 +16,7 @@ LogicalGraph::LogicalGraph(bool is_train) {
BuildModelStruct(is_train);
BuildRecordLoadStruct();
if (is_train) { ConnectFwToBw(); }
FixDecodeAndMdDiffAccAreaType();
ToDotWithAutoFilePath();
}
......@@ -29,6 +30,30 @@ void LogicalGraph::ForEachLogicalNode(std::function<void(LogicalNodeType*)> func
for (LogicalNodeType* valid_node : valid_nodes) { func(valid_node); }
}
void LogicalGraph::SetAreaIdForNewNodes(AreaType area_type) {
ForEachNode([&](LogicalNode* node) {
if (node->area_id() == static_cast<int64_t>(kInvalidArea)) {
node->set_area_id(static_cast<int64_t>(area_type));
}
});
}
void LogicalGraph::FixDecodeAndMdDiffAccAreaType() {
ForEachNode([&](LogicalNode* node) {
CHECK_NE(node->area_id(), 0);
auto decode_node = dynamic_cast<DecodeLogicalNode*>(node);
if (decode_node) {
CHECK_EQ(node->area_id(), static_cast<int64_t>(kDataForwardArea));
node->set_area_id(static_cast<int64_t>(kDataPreprocessArea));
}
auto md_diff_acc_node = dynamic_cast<MdDiffAccLogicalNode*>(node);
if (md_diff_acc_node) {
CHECK_EQ(node->area_id(), static_cast<int64_t>(kMdUpdtArea));
node->set_area_id(static_cast<int64_t>(kDataBackwardArea));
}
});
}
void LogicalGraph::BuildFwStruct(HashMap<LogicalEdge*, std::string>* edge2ibn) {
HashMap<std::string, std::vector<LogicalNode*>> op_name2nodes;
NaiveBuildFwStruct(edge2ibn, &op_name2nodes);
......@@ -39,6 +64,7 @@ void LogicalGraph::BuildFwStruct(HashMap<LogicalEdge*, std::string>* edge2ibn) {
total_mbn_num_ +=
node->SoleOp()->model_bns().size() + node->SoleOp()->forward_model_bns().size();
});
SetAreaIdForNewNodes(kDataForwardArea);
}
void LogicalGraph::NaiveBuildFwStruct(
......@@ -212,6 +238,7 @@ void LogicalGraph::SetMainModelParallel() {
void LogicalGraph::BuildBwStruct(HashMap<LogicalEdge*, std::string>* edge2ibn) {
NaiveBuildBwStruct(edge2ibn);
AddBackwardClone(*edge2ibn);
SetAreaIdForNewNodes(kDataBackwardArea);
}
void LogicalGraph::NaiveBuildBwStruct(HashMap<LogicalEdge*, std::string>* edge2ibn) {
......@@ -382,6 +409,7 @@ void LogicalGraph::BuildLossPrintStruct() {
loss_print_logical->mut_parallel_desc().reset(new ParallelDesc(loss_print_pr_conf));
Connect<LogicalNode>(loss_acc_logical, NewEdge(), loss_print_logical);
});
SetAreaIdForNewNodes(kPrintArea);
}
void LogicalGraph::BuildModelStruct(bool is_train) {
......@@ -437,6 +465,7 @@ void LogicalGraph::BuildModelStruct(bool is_train) {
}
});
SetupNormalMdUpdtOp();
SetAreaIdForNewNodes(kMdUpdtArea);
}
void LogicalGraph::BuildReduceStruct(LogicalNode* src, LogicalNode* dst) {
......@@ -507,6 +536,7 @@ MdSaveLogicalNode* LogicalGraph::BuildMdSaveStruct(const ForwardLogicalNode* fw_
md_save_pr_desc->set_device_type(DeviceType::kCPU);
md_save_logical->mut_parallel_desc().reset(md_save_pr_desc);
Connect<LogicalNode>(need_save_logical, NewEdge(), md_save_logical);
SetAreaIdForNewNodes(kMdSaveArea);
return md_save_logical;
}
......@@ -560,6 +590,7 @@ void LogicalGraph::BuildRecordLoadStruct() {
}
}
}
SetAreaIdForNewNodes(kDataPreprocessArea);
}
void LogicalGraph::ConnectFwToBw() {
......
......@@ -66,6 +66,9 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
void BuildRecordLoadStruct();
void ConnectFwToBw();
void SetAreaIdForNewNodes(AreaType area_type);
void FixDecodeAndMdDiffAccAreaType();
int64_t total_mbn_num_;
};
......
......@@ -52,8 +52,14 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
int32_t GetModelSplitAxis() const;
int32_t GetMaxModelSplitNum() const;
void set_area_id(int64_t val) {
CHECK_NE(val, 0);
area_id_ = val;
}
int64_t area_id() const { return area_id_; }
protected:
LogicalNode() : main_model_parallel_(nullptr) {}
LogicalNode() : main_model_parallel_(nullptr), area_id_(0) {}
virtual CompTaskNode* NewCompTaskNode() const = 0;
virtual void FixCompTaskNode(CompTaskNode*) const {}
......@@ -68,6 +74,7 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
HashMap<const LogicalNode*, std::vector<LogicalBlobId>> dst2data_lbis_;
HashSet<LogicalBlobId> lbi_boxing_;
HashSet<LogicalBlobId> lbi_121_;
int64_t area_id_;
};
#define BLD_SUB_TSK_GPH_MTHD_ARGS() \
......
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -37,6 +39,7 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
AllocateCpuThrdIdEvenly, [&](CompTaskNode* comp_task_node) {
AddAllocatedNode(comp_task_node);
logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
comp_task_node->set_area_id(logical_node->area_id());
});
});
logical_gph_->ForEachEdge([&](const LogicalEdge* logical_edge) {
......@@ -46,10 +49,100 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
logical2sorted_comp_tasks.at(logical_edge->src_node()),
logical2sorted_comp_tasks.at(logical_edge->dst_node()), &logical2sorted_in_box,
&logical2sorted_out_box, Mut121BufTask, AllocateCpuThrdIdEvenly);
SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
});
ToDotWithAutoFilePath();
}
void TaskGraph::FindChainsInSameStream() {
CollectAncestorsForEachNode();
ChainGraph chain_gph(*this);
const auto& ordered_chain_nodes = chain_gph.ordered_chain_nodes();
int64_t order_in_graph = 0;
HashMap<int64_t, HashSet<TaskNode*>> chain_id2task_nodes;
for (auto& chain_node : ordered_chain_nodes) {
auto& ordered_in_chain = chain_node->chain_it()->nodes;
int64_t chain_id = chain_node->chain_id();
for (auto& task_node : ordered_in_chain) {
task_node->set_chain_id(chain_id);
task_node->set_order_in_graph(order_in_graph);
CHECK(chain_id2task_nodes[chain_id].insert(task_node).second);
ordered_task_nodes_.emplace_back(task_node);
++order_in_graph;
}
}
}
void TaskGraph::AddOrderingCtrlEdgeInSameChain() {
FindChainsInSameStream();
HashMap<int64_t, TaskNode*> chain_id2node;
for (auto node : ordered_task_nodes_) {
int64_t chain_id = node->chain_id();
auto iter = chain_id2node.find(chain_id);
if (iter == chain_id2node.end()) {
CHECK(chain_id2node.emplace(chain_id, node).second);
} else {
iter->second->BuildCtrlRegstDescIfNeed(node);
iter->second = node;
}
}
}
void TaskGraph::AddMutexCtrlEdgeInSameChain() { UNIMPLEMENTED(); }
void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() { UNIMPLEMENTED(); }
void TaskGraph::CollectAncestorsForEachNode() {
std::vector<TaskNode*> ordered_nodes;
AcyclicTopoForEachNode([&](TaskNode* node) { ordered_nodes.emplace_back(node); });
for (auto it = ordered_nodes.begin(); it != ordered_nodes.end(); ++it) {
TaskNode* task_node = *it;
task_node->mut_ancestors().clear();
task_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, task_node)) return;
task_node->mut_ancestors().insert(node_on_in_edge->ancestors().begin(),
node_on_in_edge->ancestors().end());
task_node->mut_ancestors().insert(node_on_in_edge);
});
}
}
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const {
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (node->consumed_regsts().empty() && !node->IsMeaningLess()) { starts.push_back(node); }
});
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, node)) return;
handler(const_cast<TaskNode*>(node_on_in_edge));
});
};
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
if (IsBackEdge(node, node_on_out_edge)) return;
handler(const_cast<TaskNode*>(node_on_out_edge));
});
};
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, handler);
}
void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
const LogicalNode* dst_logical) {
CHECK(src_logical != nullptr && dst_logical != nullptr);
ForEachNode([&](TaskNode* node) {
if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
if (src_logical->area_id() == dst_logical->area_id()) {
node->set_area_id(src_logical->area_id());
} else {
node->set_area_id(static_cast<int64_t>(kBoundaryArea));
}
});
}
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
......@@ -264,4 +357,10 @@ void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
}
}
bool IsBackEdge(TaskNode* src, TaskNode* dst) {
return src->GetTaskType() == TaskType::kNormalMdUpdt
&& (dst->GetTaskType() == TaskType::kNormalForward
|| dst->GetTaskType() == TaskType::kNormalBackward);
}
} // namespace oneflow
......@@ -17,6 +17,10 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph);
const char* TypeName() const override { return "TaskGraph"; }
void AddOrderingCtrlEdgeInSameChain();
void AddMutexCtrlEdgeInSameChain();
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const;
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
......@@ -44,8 +48,14 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
std::function<int64_t(const TaskNode*)> AllocateCpuThrdId);
void ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst);
void SetAreaIdForNewNodes(const LogicalNode* src_logical, const LogicalNode* dst_logical);
void CollectAncestorsForEachNode();
void FindChainsInSameStream();
std::unique_ptr<const LogicalGraph> logical_gph_;
std::vector<TaskNode*> ordered_task_nodes_;
};
bool IsBackEdge(TaskNode* src, TaskNode* dst);
} // namespace oneflow
......
......@@ -16,7 +16,7 @@ TaskNode::TaskNode()
: machine_id_(-1),
thrd_id_(-1),
task_id_(-1),
area_id_(-1),
area_id_(0),
chain_id_(-1),
order_in_graph_(-1) {}
......@@ -57,6 +57,21 @@ void TaskNode::set_thrd_id(int64_t val) {
if (machine_id_ != -1) { UpdateTaskId(); }
}
void TaskNode::set_area_id(int64_t val) {
CHECK_EQ(area_id_, 0);
area_id_ = val;
}
void TaskNode::set_chain_id(int64_t val) {
CHECK_EQ(chain_id_, -1);
chain_id_ = val;
}
void TaskNode::set_order_in_graph(int64_t val) {
CHECK_EQ(order_in_graph_, -1);
order_in_graph_ = val;
}
void TaskNode::PinConsumedRegst() {
for (auto& pair : consumed_regsts_) {
for (std::weak_ptr<RegstDesc> regst : pair.second) {
......@@ -126,10 +141,8 @@ int64_t TaskNode::MemZoneId121() const {
}
void TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node) {
for (auto& name2regst : produced_regsts_) {
const auto& consumers = name2regst.second->consumers();
if (consumers.find(dst_node) != consumers.end()) { return; }
}
const auto& dst_ancestors = dst_node->ancestors();
if (dst_ancestors.find(this) != dst_ancestors.end()) return;
RegstDescTypeProto regst_desc_type;
regst_desc_type.mutable_ctrl_regst_desc();
auto regst = NewProducedRegst(1, kMaxRegisterNum, regst_desc_type);
......
......@@ -37,6 +37,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
const HashMap<std::string, std::list<std::weak_ptr<RegstDesc>>>& consumed_regsts() {
return consumed_regsts_;
}
const HashSet<TaskNode*> ancestors() const { return ancestors_; }
HashSet<TaskNode*>& mut_ancestors() { return ancestors_; }
DeviceType device_type() const;
virtual const ParallelContext* parallel_ctx() const { return nullptr; }
int64_t LocalWorkStreamId() const;
......@@ -47,6 +49,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
// Setters
void set_machine_id(int64_t val);
void set_thrd_id(int64_t val);
void set_area_id(int64_t val);
void set_chain_id(int64_t val);
void set_order_in_graph(int64_t val);
// Build
virtual void ProduceAllRegstsAndBindEdges() = 0;
......@@ -102,6 +107,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
ExecGraph exec_gph_;
HashMap<std::string, std::shared_ptr<RegstDesc>> produced_regsts_;
HashMap<std::string, std::list<std::weak_ptr<RegstDesc>>> consumed_regsts_;
HashSet<TaskNode*> ancestors_;
};
class TaskEdge final : public Edge<TaskNode, TaskEdge> {
......
......@@ -59,7 +59,7 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::Build, _1), std::bind(&TaskNode::IsReadyForBuild, _1));
task_gph->ForEachNode(std::bind(&TaskNode::EraseEmptyProducedRegst, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ClearOutOfDateConsumedRegst, _1));
OrderTaskNodesInSameStream(task_gph.get());
task_gph->AddOrderingCtrlEdgeInSameChain();
Plan plan;
task_gph->ForEachNode([&](TaskNode* task_node) {
if (task_node->IsMeaningLess()) { return; }
......@@ -86,47 +86,4 @@ Plan Compiler::DoCompile() {
return plan;
}
void Compiler::OrderTaskNodesInSameStream(TaskGraph* task_gph) {
std::list<TaskNode*> starts;
task_gph->ForEachNode([&](TaskNode* node) {
if (node->consumed_regsts().empty() && !node->IsMeaningLess()) { starts.push_back(node); }
});
HashMap<int64_t, TaskNode*> stream_id2node;
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
const auto& consumed_regsts = node->consumed_regsts();
for (const auto& name2regsts : consumed_regsts) {
for (auto& regst : name2regsts.second) {
const TaskNode* producer = regst.lock()->producer();
if (producer->GetTaskType() != TaskType::kNormalMdUpdt) {
handler(const_cast<TaskNode*>(producer));
}
}
}
};
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
const auto& produced_regsts = node->produced_regsts();
for (const auto& name2regst : produced_regsts) {
const auto& consumers = name2regst.second->consumers();
for (const TaskNode* consumer : consumers) {
TaskType task_type = consumer->GetTaskType();
if (task_type != TaskType::kMdDiffAcc && task_type != TaskType::kNormalMdUpdt
&& task_type != TaskType::kLossAcc && task_type != TaskType::kMdSave) {
handler(const_cast<TaskNode*>(consumer));
}
}
}
};
task_gph->TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](TaskNode* node) {
if (Global<IDMgr>::Get()->IsIndependentLocalWorkStreamId(node->LocalWorkStreamId())) { return; }
int64_t global_stream_id = node->GlobalWorkStreamId();
auto iter = stream_id2node.find(global_stream_id);
if (iter == stream_id2node.end()) {
CHECK(stream_id2node.emplace(global_stream_id, node).second);
} else {
iter->second->BuildCtrlRegstDescIfNeed(node);
iter->second = node;
}
});
}
} // namespace oneflow
......@@ -18,7 +18,6 @@ class Compiler final {
private:
Plan DoCompile();
void OrderTaskNodesInSameStream(TaskGraph* task_gph);
};
} // namespace oneflow
......
......@@ -78,6 +78,12 @@ int64_t IDMgr::LocalWorkStreamId4ActorId(int64_t actor_id) const {
return LocalWorkStreamId4TaskId(actor_id);
}
int64_t IDMgr::AllocateChainId(int64_t global_work_stream_id) {
CHECK_LT(stream_id2chain_cnt_[global_work_stream_id],
(static_cast<int64_t>(1) << task_id_bit_num_) - 1);
return global_work_stream_id | (stream_id2chain_cnt_[global_work_stream_id]++);
}
IDMgr::IDMgr() {
const Resource& resource = Global<JobDesc>::Get()->resource();
CHECK_LT(resource.machine_size(), static_cast<int64_t>(1) << machine_id_bit_num_);
......
......@@ -56,6 +56,7 @@ class IDMgr final {
// 1 | 10 | 11 | 21 | 21
int64_t GlobalWorkStreamId4ActorId(int64_t actor_id) const;
int64_t GlobalWorkStreamId4TaskId(int64_t task_id) const;
int64_t AllocateChainId(int64_t global_work_stream_id);
private:
friend class Global<IDMgr>;
......@@ -67,6 +68,7 @@ class IDMgr final {
int64_t regst_desc_id_count_;
HashMap<int64_t, int64_t> machine_thrd_id2num_of_tasks_;
HashMap<int64_t, int64_t> machine_thrd_id2stream_id_cnt_;
HashMap<int64_t, int64_t> stream_id2chain_cnt_;
// 64 bit id design:
// sign | machine | thread | local_work_stream | task
......
......@@ -28,6 +28,17 @@ enum TaskType {
kReduceScatter = 19;
};
enum AreaType {
kInvalidArea = 0;
kDataPreprocessArea = 1;
kDataForwardArea = 2;
kDataBackwardArea = 3;
kMdUpdtArea = 4;
kMdSaveArea = 5;
kPrintArea = 6;
kBoundaryArea = 7;
}
message RegstDescIdSet {
repeated int64 regst_desc_id = 1;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册